diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt index db96a80051a8d..2b2d0985a8992 100644 --- a/libc/config/linux/aarch64/entrypoints.txt +++ b/libc/config/linux/aarch64/entrypoints.txt @@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16) libc.src.math.canonicalizef16 libc.src.math.ceilf16 libc.src.math.copysignf16 + libc.src.math.f16sqrtf libc.src.math.fabsf16 libc.src.math.fdimf16 libc.src.math.floorf16 diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt index 355eaf33ace6d..2d36ca296c3a4 100644 --- a/libc/config/linux/x86_64/entrypoints.txt +++ b/libc/config/linux/x86_64/entrypoints.txt @@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16) libc.src.math.canonicalizef16 libc.src.math.ceilf16 libc.src.math.copysignf16 + libc.src.math.f16sqrtf libc.src.math.fabsf16 libc.src.math.fdimf16 libc.src.math.floorf16 diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst index d556885eda622..790786147c164 100644 --- a/libc/docs/math/index.rst +++ b/libc/docs/math/index.rst @@ -280,6 +280,8 @@ Higher Math Functions +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ | fma | |check| | |check| | | | | 7.12.13.1 | F.10.10.1 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ +| f16sqrt | |check| | | | N/A | | 7.12.14.6 | F.10.11 | ++-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ | fsqrt | N/A | | | N/A | | 7.12.14.6 | F.10.11 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ | hypot | |check| | |check| | | | | 7.12.7.4 | F.10.4.4 | diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td index b134ec00a7d7a..7c4135032a0b2 100644 --- a/libc/spec/stdc.td +++ b/libc/spec/stdc.td @@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> { GuardedFunctionSpec<"totalorderf16", RetValSpec, [ArgSpec, ArgSpec], "LIBC_TYPES_HAS_FLOAT16">, GuardedFunctionSpec<"totalordermagf16", RetValSpec, [ArgSpec, ArgSpec], "LIBC_TYPES_HAS_FLOAT16">, + + GuardedFunctionSpec<"f16sqrtf", RetValSpec, [ArgSpec], "LIBC_TYPES_HAS_FLOAT16">, ] >; diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt index 09eede1570962..595656e3e8d90 100644 --- a/libc/src/__support/FPUtil/generic/CMakeLists.txt +++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt @@ -4,6 +4,7 @@ add_header_library( sqrt.h sqrt_80_bit_long_double.h DEPENDS + libc.hdr.fenv_macros libc.src.__support.common libc.src.__support.CPP.bit libc.src.__support.CPP.type_traits diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h index 7e7600ba6502a..d6e894fdfe021 100644 --- a/libc/src/__support/FPUtil/generic/sqrt.h +++ b/libc/src/__support/FPUtil/generic/sqrt.h @@ -18,6 +18,8 @@ #include "src/__support/common.h" #include "src/__support/uint128.h" +#include "hdr/fenv_macros.h" + namespace LIBC_NAMESPACE { namespace fputil { @@ -64,40 +66,50 @@ LIBC_INLINE void normalize(int &exponent, UInt128 &mantissa) { // Correctly rounded IEEE 754 SQRT for all rounding modes. // Shift-and-add algorithm. -template -LIBC_INLINE cpp::enable_if_t, T> sqrt(T x) { - - if constexpr (internal::SpecialLongDouble::VALUE) { +template +LIBC_INLINE cpp::enable_if_t && + cpp::is_floating_point_v && + sizeof(OutType) <= sizeof(InType), + OutType> +sqrt(InType x) { + if constexpr (internal::SpecialLongDouble::VALUE && + internal::SpecialLongDouble::VALUE) { // Special 80-bit long double. return x86::sqrt(x); } else { // IEEE floating points formats. - using FPBits_t = typename fputil::FPBits; - using StorageType = typename FPBits_t::StorageType; - constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN; - constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val(); - - FPBits_t bits(x); - - if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) { + using OutFPBits = typename fputil::FPBits; + using OutStorageType = typename OutFPBits::StorageType; + using InFPBits = typename fputil::FPBits; + using InStorageType = typename InFPBits::StorageType; + constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN; + constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val(); + constexpr int EXTRA_FRACTION_LEN = + InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN; + constexpr InStorageType EXTRA_FRACTION_MASK = + (InStorageType(1) << EXTRA_FRACTION_LEN) - 1; + + InFPBits bits(x); + + if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) { // sqrt(+Inf) = +Inf // sqrt(+0) = +0 // sqrt(-0) = -0 // sqrt(NaN) = NaN // sqrt(-NaN) = -NaN - return x; + return static_cast(x); } else if (bits.is_neg()) { // sqrt(-Inf) = NaN // sqrt(-x) = NaN return FLT_NAN; } else { int x_exp = bits.get_exponent(); - StorageType x_mant = bits.get_mantissa(); + InStorageType x_mant = bits.get_mantissa(); // Step 1a: Normalize denormal input and append hidden bit to the mantissa if (bits.is_subnormal()) { ++x_exp; // let x_exp be the correct exponent of ONE bit. - internal::normalize(x_exp, x_mant); + internal::normalize(x_exp, x_mant); } else { x_mant |= ONE; } @@ -120,12 +132,13 @@ LIBC_INLINE cpp::enable_if_t, T> sqrt(T x) { // So the nth digit y_n of the mantissa of sqrt(x) can be found by: // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1) // 0 otherwise. - StorageType y = ONE; - StorageType r = x_mant - ONE; + InStorageType y = ONE; + InStorageType r = x_mant - ONE; - for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) { + for (InStorageType current_bit = ONE >> 1; current_bit; + current_bit >>= 1) { r <<= 1; - StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1) + InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1) if (r >= tmp) { r -= tmp; y += current_bit; @@ -133,34 +146,91 @@ LIBC_INLINE cpp::enable_if_t, T> sqrt(T x) { } // We compute one more iteration in order to round correctly. - bool lsb = static_cast(y & 1); // Least significant bit - bool rb = false; // Round bit + bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) != + 0; // Least significant bit + bool rb = false; // Round bit r <<= 2; - StorageType tmp = (y << 2) + 1; + InStorageType tmp = (y << 2) + 1; if (r >= tmp) { r -= tmp; rb = true; } + bool sticky = false; + + if constexpr (EXTRA_FRACTION_LEN > 0) { + sticky = rb || (y & EXTRA_FRACTION_MASK) != 0; + rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0; + } + // Remove hidden bit and append the exponent field. - x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS); + x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS); + + OutStorageType y_out = static_cast( + ((y - ONE) >> EXTRA_FRACTION_LEN) | + (static_cast(x_exp) << OutFPBits::FRACTION_LEN)); + + if constexpr (EXTRA_FRACTION_LEN > 0) { + if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) { + switch (quick_get_round()) { + case FE_TONEAREST: + case FE_UPWARD: + return OutFPBits::inf().get_val(); + default: + return OutFPBits::max_normal().get_val(); + } + } + + if (x_exp < + -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) { + switch (quick_get_round()) { + case FE_UPWARD: + return OutFPBits::min_subnormal().get_val(); + default: + return OutType(0.0); + } + } - y = (y - ONE) | - (static_cast(x_exp) << FPBits_t::FRACTION_LEN); + if (x_exp <= 0) { + int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1; + InStorageType underflow_extra_fraction_mask = + (InStorageType(1) << underflow_extra_fraction_len) - 1; + + rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) != + 0; + OutStorageType subnormal_mant = + static_cast(y >> underflow_extra_fraction_len); + lsb = (subnormal_mant & 1) != 0; + sticky = sticky || (y & underflow_extra_fraction_mask) != 0; + + switch (quick_get_round()) { + case FE_TONEAREST: + if (rb && (lsb || sticky)) + ++subnormal_mant; + break; + case FE_UPWARD: + if (rb || sticky) + ++subnormal_mant; + break; + } + + return cpp::bit_cast(subnormal_mant); + } + } switch (quick_get_round()) { case FE_TONEAREST: // Round to nearest, ties to even if (rb && (lsb || (r != 0))) - ++y; + ++y_out; break; case FE_UPWARD: - if (rb || (r != 0)) - ++y; + if (rb || (r != 0) || sticky) + ++y_out; break; } - return cpp::bit_cast(y); + return cpp::bit_cast(y_out); } } } diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt index 2446c293b8ef5..df8e6c0b253da 100644 --- a/libc/src/math/CMakeLists.txt +++ b/libc/src/math/CMakeLists.txt @@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f) add_math_entrypoint_object(expm1) add_math_entrypoint_object(expm1f) +add_math_entrypoint_object(f16sqrtf) + add_math_entrypoint_object(fabs) add_math_entrypoint_object(fabsf) add_math_entrypoint_object(fabsl) diff --git a/libc/src/math/f16sqrtf.h b/libc/src/math/f16sqrtf.h new file mode 100644 index 0000000000000..197ebe6db8016 --- /dev/null +++ b/libc/src/math/f16sqrtf.h @@ -0,0 +1,20 @@ +//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H +#define LLVM_LIBC_SRC_MATH_F16SQRTF_H + +#include "src/__support/macros/properties/types.h" + +namespace LIBC_NAMESPACE { + +float16 f16sqrtf(float x); + +} // namespace LIBC_NAMESPACE + +#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt index 673bef516b13d..f1f7d6c367be2 100644 --- a/libc/src/math/generic/CMakeLists.txt +++ b/libc/src/math/generic/CMakeLists.txt @@ -3601,3 +3601,16 @@ add_entrypoint_object( COMPILE_OPTIONS -O3 ) + +add_entrypoint_object( + f16sqrtf + SRCS + f16sqrtf.cpp + HDRS + ../f16sqrtf.h + DEPENDS + libc.src.__support.macros.properties.types + libc.src.__support.FPUtil.sqrt + COMPILE_OPTIONS + -O3 +) diff --git a/libc/src/math/generic/acosf.cpp b/libc/src/math/generic/acosf.cpp index e6e28d43ef61f..f02edec267174 100644 --- a/libc/src/math/generic/acosf.cpp +++ b/libc/src/math/generic/acosf.cpp @@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) { xbits.set_sign(Sign::POS); double xd = static_cast(xbits.get_val()); double u = fputil::multiply_add(-0.5, xd, 0.5); - double cv = 2 * fputil::sqrt(u); + double cv = 2 * fputil::sqrt(u); double r3 = asin_eval(u); double r = fputil::multiply_add(cv * u, r3, cv); diff --git a/libc/src/math/generic/acoshf.cpp b/libc/src/math/generic/acoshf.cpp index a4a75a7b04385..9422ec63e1ce2 100644 --- a/libc/src/math/generic/acoshf.cpp +++ b/libc/src/math/generic/acoshf.cpp @@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) { double x_d = static_cast(x); // acosh(x) = log(x + sqrt(x^2 - 1)) - return static_cast( - log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0)))); + return static_cast(log_eval( + x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0)))); } } // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/asinf.cpp b/libc/src/math/generic/asinf.cpp index d9133333d2561..c4afca493a713 100644 --- a/libc/src/math/generic/asinf.cpp +++ b/libc/src/math/generic/asinf.cpp @@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) { double sign = SIGN[x_sign]; double xd = static_cast(xbits.get_val()); double u = fputil::multiply_add(-0.5, xd, 0.5); - double c1 = sign * (-2 * fputil::sqrt(u)); + double c1 = sign * (-2 * fputil::sqrt(u)); double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1); double c3 = c1 * u; diff --git a/libc/src/math/generic/asinhf.cpp b/libc/src/math/generic/asinhf.cpp index 6e351786e3eca..82dc2a31ebc22 100644 --- a/libc/src/math/generic/asinhf.cpp +++ b/libc/src/math/generic/asinhf.cpp @@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) { // asinh(x) = log(x + sqrt(x^2 + 1)) return static_cast( - x_sign * - log_eval(fputil::multiply_add( - x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0))))); + x_sign * log_eval(fputil::multiply_add( + x_d, x_sign, + fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0))))); } } // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/f16sqrtf.cpp b/libc/src/math/generic/f16sqrtf.cpp new file mode 100644 index 0000000000000..1f7ee2df29e86 --- /dev/null +++ b/libc/src/math/generic/f16sqrtf.cpp @@ -0,0 +1,19 @@ +//===-- Implementation of f16sqrtf function -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/math/f16sqrtf.h" +#include "src/__support/FPUtil/sqrt.h" +#include "src/__support/common.h" + +namespace LIBC_NAMESPACE { + +LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) { + return fputil::sqrt(x); +} + +} // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp index ffbf706aefaf6..b09d09ad7f9c9 100644 --- a/libc/src/math/generic/hypotf.cpp +++ b/libc/src/math/generic/hypotf.cpp @@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) { double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq; // Take sqrt in double precision. - DoubleBits result(fputil::sqrt(sum_sq)); + DoubleBits result(fputil::sqrt(sum_sq)); if (!DoubleBits(sum_sq).is_inf_or_nan()) { // Correct rounding. diff --git a/libc/src/math/generic/powf.cpp b/libc/src/math/generic/powf.cpp index 59efc3f424c76..13c04240f59c2 100644 --- a/libc/src/math/generic/powf.cpp +++ b/libc/src/math/generic/powf.cpp @@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) { switch (y_u) { case 0x3f00'0000: // y = 0.5f // pow(x, 1/2) = sqrt(x) - return fputil::sqrt(x); + return fputil::sqrt(x); case 0x3f80'0000: // y = 1.0f return x; case 0x4000'0000: // y = 2.0f diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp index b4d02785dcb43..f33b0a2cdcf74 100644 --- a/libc/src/math/generic/sqrt.cpp +++ b/libc/src/math/generic/sqrt.cpp @@ -12,6 +12,6 @@ namespace LIBC_NAMESPACE { -LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); } +LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); } } // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp index bc74252295b3a..26a53e9077c1c 100644 --- a/libc/src/math/generic/sqrtf.cpp +++ b/libc/src/math/generic/sqrtf.cpp @@ -12,6 +12,6 @@ namespace LIBC_NAMESPACE { -LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); } +LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); } } // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/sqrtf128.cpp b/libc/src/math/generic/sqrtf128.cpp index 0196c3e0a96ae..70e28ddb692d4 100644 --- a/libc/src/math/generic/sqrtf128.cpp +++ b/libc/src/math/generic/sqrtf128.cpp @@ -12,6 +12,8 @@ namespace LIBC_NAMESPACE { -LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); } +LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { + return fputil::sqrt(x); +} } // namespace LIBC_NAMESPACE diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp index b2aaa279f9c2a..9f0cc87853823 100644 --- a/libc/src/math/generic/sqrtl.cpp +++ b/libc/src/math/generic/sqrtl.cpp @@ -13,7 +13,7 @@ namespace LIBC_NAMESPACE { LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) { - return fputil::sqrt(x); + return fputil::sqrt(x); } } // namespace LIBC_NAMESPACE diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt index 938e519aff084..34df8720ed4db 100644 --- a/libc/test/src/math/exhaustive/CMakeLists.txt +++ b/libc/test/src/math/exhaustive/CMakeLists.txt @@ -420,3 +420,18 @@ add_fp_unittest( LINK_LIBRARIES -lpthread ) + +add_fp_unittest( + f16sqrtf_test + NO_RUN_POSTBUILD + NEED_MPFR + SUITE + libc_math_exhaustive_tests + SRCS + f16sqrtf_test.cpp + DEPENDS + .exhaustive_test + libc.src.math.f16sqrtf + LINK_LIBRARIES + -lpthread +) diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h index c4ae382688a03..13e272783250b 100644 --- a/libc/test/src/math/exhaustive/exhaustive_test.h +++ b/libc/test/src/math/exhaustive/exhaustive_test.h @@ -35,16 +35,16 @@ // LlvmLibcUnaryOpExhaustiveMathTest. namespace mpfr = LIBC_NAMESPACE::testing::mpfr; -template using UnaryOp = T(T); +template +using UnaryOp = OutType(InType); -template Func> +template Func> struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { - using FloatType = T; + using FloatType = InType; using FPBits = LIBC_NAMESPACE::fputil::FPBits; using StorageType = typename FPBits::StorageType; - static constexpr UnaryOp *FUNC = Func; - // Check in a range, return the number of failures. uint64_t check(StorageType start, StorageType stop, mpfr::RoundingMode rounding) { @@ -57,11 +57,11 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { FPBits xbits(bits); FloatType x = xbits.get_val(); bool correct = - TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding); + TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding); failed += (!correct); // Uncomment to print out failed values. // if (!correct) { - // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x), 0.5, rounding); + // EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding); // } } while (bits++ < stop); return failed; @@ -169,4 +169,9 @@ struct LlvmLibcExhaustiveMathTest template Func> using LlvmLibcUnaryOpExhaustiveMathTest = - LlvmLibcExhaustiveMathTest>; + LlvmLibcExhaustiveMathTest>; + +template Func> +using LlvmLibcUnaryNarrowingOpExhaustiveMathTest = + LlvmLibcExhaustiveMathTest>; diff --git a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp new file mode 100644 index 0000000000000..3a42ff8e0725d --- /dev/null +++ b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp @@ -0,0 +1,25 @@ +//===-- Exhaustive test for f16sqrtf --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "exhaustive_test.h" +#include "src/math/f16sqrtf.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +namespace mpfr = LIBC_NAMESPACE::testing::mpfr; + +using LlvmLibcF16sqrtfExhaustiveTest = + LlvmLibcUnaryNarrowingOpExhaustiveMathTest< + float16, float, mpfr::Operation::Sqrt, LIBC_NAMESPACE::f16sqrtf>; + +// Range: [0, Inf]; +static constexpr uint32_t POS_START = 0x0000'0000U; +static constexpr uint32_t POS_STOP = 0x7f80'0000U; + +TEST_F(LlvmLibcF16sqrtfExhaustiveTest, PostiveRange) { + test_full_range_all_roundings(POS_START, POS_STOP); +} diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt index 68cd412b14e9d..3bb87d2b0d0f3 100644 --- a/libc/test/src/math/smoke/CMakeLists.txt +++ b/libc/test/src/math/smoke/CMakeLists.txt @@ -2504,9 +2504,10 @@ add_fp_unittest( libc-math-smoke-tests SRCS sqrtf_test.cpp + HDRS + SqrtTest.h DEPENDS libc.src.math.sqrtf - libc.src.__support.FPUtil.fp_bits ) add_fp_unittest( @@ -2515,9 +2516,10 @@ add_fp_unittest( libc-math-smoke-tests SRCS sqrt_test.cpp + HDRS + SqrtTest.h DEPENDS libc.src.math.sqrt - libc.src.__support.FPUtil.fp_bits ) add_fp_unittest( @@ -2526,9 +2528,10 @@ add_fp_unittest( libc-math-smoke-tests SRCS sqrtl_test.cpp + HDRS + SqrtTest.h DEPENDS libc.src.math.sqrtl - libc.src.__support.FPUtil.fp_bits ) add_fp_unittest( @@ -2537,9 +2540,10 @@ add_fp_unittest( libc-math-smoke-tests SRCS sqrtf128_test.cpp + HDRS + SqrtTest.h DEPENDS libc.src.math.sqrtf128 - libc.src.__support.FPUtil.fp_bits ) add_fp_unittest( @@ -2548,9 +2552,9 @@ add_fp_unittest( libc-math-smoke-tests SRCS generic_sqrtf_test.cpp + HDRS + SqrtTest.h DEPENDS - libc.src.math.sqrtf - libc.src.__support.FPUtil.fp_bits libc.src.__support.FPUtil.generic.sqrt COMPILE_OPTIONS -O3 @@ -2562,9 +2566,9 @@ add_fp_unittest( libc-math-smoke-tests SRCS generic_sqrt_test.cpp + HDRS + SqrtTest.h DEPENDS - libc.src.math.sqrt - libc.src.__support.FPUtil.fp_bits libc.src.__support.FPUtil.generic.sqrt COMPILE_OPTIONS -O3 @@ -2576,9 +2580,9 @@ add_fp_unittest( libc-math-smoke-tests SRCS generic_sqrtl_test.cpp + HDRS + SqrtTest.h DEPENDS - libc.src.math.sqrtl - libc.src.__support.FPUtil.fp_bits libc.src.__support.FPUtil.generic.sqrt COMPILE_OPTIONS -O3 @@ -2590,9 +2594,9 @@ add_fp_unittest( libc-math-smoke-tests SRCS generic_sqrtf128_test.cpp + HDRS + SqrtTest.h DEPENDS - libc.src.math.sqrtf128 - libc.src.__support.FPUtil.fp_bits libc.src.__support.FPUtil.generic.sqrt COMPILE_OPTIONS -O3 @@ -3543,3 +3547,15 @@ add_fp_unittest( DEPENDS libc.src.math.totalordermagf16 ) + +add_fp_unittest( + f16sqrtf_test + SUITE + libc-math-smoke-tests + SRCS + f16sqrtf_test.cpp + HDRS + SqrtTest.h + DEPENDS + libc.src.math.f16sqrtf +) diff --git a/libc/test/src/math/smoke/SqrtTest.h b/libc/test/src/math/smoke/SqrtTest.h index 8afacaf01ae42..ce9f2f85b4604 100644 --- a/libc/test/src/math/smoke/SqrtTest.h +++ b/libc/test/src/math/smoke/SqrtTest.h @@ -6,37 +6,35 @@ // //===----------------------------------------------------------------------===// -#include "src/__support/CPP/bit.h" #include "test/UnitTest/FEnvSafeTest.h" #include "test/UnitTest/FPMatcher.h" #include "test/UnitTest/Test.h" -#include "hdr/math_macros.h" - -template +template class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest { - DECLARE_SPECIAL_CONSTANTS(T) - - static constexpr StorageType HIDDEN_BIT = - StorageType(1) << LIBC_NAMESPACE::fputil::FPBits::FRACTION_LEN; + DECLARE_SPECIAL_CONSTANTS(OutType) public: - typedef T (*SqrtFunc)(T); + typedef OutType (*SqrtFunc)(InType); void test_special_numbers(SqrtFunc func) { ASSERT_FP_EQ(aNaN, func(aNaN)); ASSERT_FP_EQ(inf, func(inf)); ASSERT_FP_EQ(aNaN, func(neg_inf)); - ASSERT_FP_EQ(0.0, func(0.0)); - ASSERT_FP_EQ(-0.0, func(-0.0)); - ASSERT_FP_EQ(aNaN, func(T(-1.0))); - ASSERT_FP_EQ(T(1.0), func(T(1.0))); - ASSERT_FP_EQ(T(2.0), func(T(4.0))); - ASSERT_FP_EQ(T(3.0), func(T(9.0))); + ASSERT_FP_EQ(zero, func(zero)); + ASSERT_FP_EQ(neg_zero, func(neg_zero)); + ASSERT_FP_EQ(aNaN, func(InType(-1.0))); + ASSERT_FP_EQ(OutType(1.0), func(InType(1.0))); + ASSERT_FP_EQ(OutType(2.0), func(InType(4.0))); + ASSERT_FP_EQ(OutType(3.0), func(InType(9.0))); } }; #define LIST_SQRT_TESTS(T, func) \ - using LlvmLibcSqrtTest = SqrtTest; \ + using LlvmLibcSqrtTest = SqrtTest; \ + TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); } + +#define LIST_NARROWING_SQRT_TESTS(OutType, InType, func) \ + using LlvmLibcSqrtTest = SqrtTest; \ TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); } diff --git a/libc/test/src/math/smoke/f16sqrtf_test.cpp b/libc/test/src/math/smoke/f16sqrtf_test.cpp new file mode 100644 index 0000000000000..36231aeb4184d --- /dev/null +++ b/libc/test/src/math/smoke/f16sqrtf_test.cpp @@ -0,0 +1,13 @@ +//===-- Unittests for f16sqrtf --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SqrtTest.h" + +#include "src/math/f16sqrtf.h" + +LIST_NARROWING_SQRT_TESTS(float16, float, LIBC_NAMESPACE::f16sqrtf) diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt index 6af6fd7707041..e74b02204ed6f 100644 --- a/libc/utils/MPFRWrapper/CMakeLists.txt +++ b/libc/utils/MPFRWrapper/CMakeLists.txt @@ -7,6 +7,8 @@ if(LIBC_TESTS_CAN_USE_MPFR) target_compile_options(libcMPFRWrapper PRIVATE -O3) add_dependencies( libcMPFRWrapper + libc.src.__support.CPP.array + libc.src.__support.CPP.stringstream libc.src.__support.CPP.string_view libc.src.__support.CPP.type_traits libc.src.__support.FPUtil.fp_bits diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp index 6918139fa83b7..100c6b1644b16 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -8,8 +8,10 @@ #include "MPFRUtils.h" +#include "src/__support/CPP/array.h" #include "src/__support/CPP/string.h" #include "src/__support/CPP/string_view.h" +#include "src/__support/CPP/stringstream.h" #include "src/__support/FPUtil/FPBits.h" #include "src/__support/FPUtil/fpbits_str.h" #include "src/__support/macros/properties/types.h" @@ -755,39 +757,51 @@ ternary_operation_one_output(Operation op, InputType x, InputType y, // to build the complete error messages before sending it to the outstream `OS` // once at the end. This will stop the error messages from interleaving when // the tests are running concurrently. -template -void explain_unary_operation_single_output_error(Operation op, T input, - T matchValue, +template +void explain_unary_operation_single_output_error(Operation op, InputType input, + OutputType matchValue, double ulp_tolerance, RoundingMode rounding) { - unsigned int precision = get_precision(ulp_tolerance); + unsigned int precision = get_precision(ulp_tolerance); MPFRNumber mpfrInput(input, precision); MPFRNumber mpfr_result; mpfr_result = unary_operation(op, input, precision, rounding); MPFRNumber mpfrMatchValue(matchValue); - tlog << "Match value not within tolerance value of MPFR result:\n" - << " Input decimal: " << mpfrInput.str() << '\n'; - tlog << " Input bits: " << str(FPBits(input)) << '\n'; - tlog << '\n' << " Match decimal: " << mpfrMatchValue.str() << '\n'; - tlog << " Match bits: " << str(FPBits(matchValue)) << '\n'; - tlog << '\n' << " MPFR result: " << mpfr_result.str() << '\n'; - tlog << " MPFR rounded: " << str(FPBits(mpfr_result.as())) << '\n'; - tlog << '\n'; - tlog << " ULP error: " - << mpfr_result.ulp_as_mpfr_number(matchValue).str() << '\n'; + cpp::array msg_buf; + cpp::StringStream msg(msg_buf); + msg << "Match value not within tolerance value of MPFR result:\n" + << " Input decimal: " << mpfrInput.str() << '\n'; + msg << " Input bits: " << str(FPBits(input)) << '\n'; + msg << '\n' << " Match decimal: " << mpfrMatchValue.str() << '\n'; + msg << " Match bits: " << str(FPBits(matchValue)) << '\n'; + msg << '\n' << " MPFR result: " << mpfr_result.str() << '\n'; + msg << " MPFR rounded: " + << str(FPBits(mpfr_result.as())) << '\n'; + msg << '\n'; + msg << " ULP error: " << mpfr_result.ulp_as_mpfr_number(matchValue).str() + << '\n'; + if (msg.overflow()) + __builtin_unreachable(); + tlog << msg.str(); } -template void explain_unary_operation_single_output_error(Operation op, - float, float, - double, - RoundingMode); -template void explain_unary_operation_single_output_error( - Operation op, double, double, double, RoundingMode); -template void explain_unary_operation_single_output_error( - Operation op, long double, long double, double, RoundingMode); +template void explain_unary_operation_single_output_error(Operation op, float, + float, double, + RoundingMode); +template void explain_unary_operation_single_output_error(Operation op, double, + double, double, + RoundingMode); +template void explain_unary_operation_single_output_error(Operation op, + long double, + long double, double, + RoundingMode); #ifdef LIBC_TYPES_HAS_FLOAT16 -template void explain_unary_operation_single_output_error( - Operation op, float16, float16, double, RoundingMode); +template void explain_unary_operation_single_output_error(Operation op, float16, + float16, double, + RoundingMode); +template void explain_unary_operation_single_output_error(Operation op, float, + float16, double, + RoundingMode); #endif template @@ -949,29 +963,30 @@ template void explain_ternary_operation_one_output_error( Operation, const TernaryInput &, long double, double, RoundingMode); -template -bool compare_unary_operation_single_output(Operation op, T input, T libc_result, +template +bool compare_unary_operation_single_output(Operation op, InputType input, + OutputType libc_result, double ulp_tolerance, RoundingMode rounding) { - unsigned int precision = get_precision(ulp_tolerance); + unsigned int precision = get_precision(ulp_tolerance); MPFRNumber mpfr_result; mpfr_result = unary_operation(op, input, precision, rounding); double ulp = mpfr_result.ulp(libc_result); return (ulp <= ulp_tolerance); } -template bool compare_unary_operation_single_output(Operation, float, - float, double, - RoundingMode); -template bool compare_unary_operation_single_output(Operation, double, - double, double, - RoundingMode); -template bool compare_unary_operation_single_output( - Operation, long double, long double, double, RoundingMode); +template bool compare_unary_operation_single_output(Operation, float, float, + double, RoundingMode); +template bool compare_unary_operation_single_output(Operation, double, double, + double, RoundingMode); +template bool compare_unary_operation_single_output(Operation, long double, + long double, double, + RoundingMode); #ifdef LIBC_TYPES_HAS_FLOAT16 -template bool compare_unary_operation_single_output(Operation, float16, - float16, double, - RoundingMode); +template bool compare_unary_operation_single_output(Operation, float16, float16, + double, RoundingMode); +template bool compare_unary_operation_single_output(Operation, float, float16, + double, RoundingMode); #endif template diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h index d2f73e2628e16..805678b96c2ef 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -129,8 +129,9 @@ struct AreMatchingBinaryInputAndBinaryOutput, BinaryOutput> { static constexpr bool VALUE = cpp::is_floating_point_v; }; -template -bool compare_unary_operation_single_output(Operation op, T input, T libc_output, +template +bool compare_unary_operation_single_output(Operation op, InputType input, + OutputType libc_output, double ulp_tolerance, RoundingMode rounding); template @@ -157,9 +158,9 @@ bool compare_ternary_operation_one_output(Operation op, T libc_output, double ulp_tolerance, RoundingMode rounding); -template -void explain_unary_operation_single_output_error(Operation op, T input, - T match_value, +template +void explain_unary_operation_single_output_error(Operation op, InputType input, + OutputType match_value, double ulp_tolerance, RoundingMode rounding); template @@ -212,7 +213,7 @@ class MPFRMatcher : public testing::Matcher { bool is_silent() const override { return silent; } private: - template bool match(T in, T out) { + template bool match(T in, U out) { return compare_unary_operation_single_output(op, in, out, ulp_tolerance, rounding); } @@ -238,7 +239,7 @@ class MPFRMatcher : public testing::Matcher { rounding); } - template void explain_error(T in, T out) { + template void explain_error(T in, U out) { explain_unary_operation_single_output_error(op, in, out, ulp_tolerance, rounding); } @@ -271,6 +272,12 @@ class MPFRMatcher : public testing::Matcher { // types. template constexpr bool is_valid_operation() { + constexpr bool IS_NARROWING_OP = op == Operation::Sqrt && + cpp::is_floating_point_v && + cpp::is_floating_point_v && + sizeof(OutputType) <= sizeof(InputType); + if (IS_NARROWING_OP) + return true; return (Operation::BeginUnaryOperationsSingleOutput < op && op < Operation::EndUnaryOperationsSingleOutput && cpp::is_same_v &&