diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index c5dea5f18030f2d226c86e3408ea85b2b5989728..b8c0be48336f67635f046994f2e0af13576c0ac3 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -195,6 +195,7 @@ xla_test(
     name = "math_test",
     srcs = ["math_test.cc"],
     deps = [
+        ":constants",
         ":math",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
index 7f423d54dbb7ff911398b0137b482ee47f46c5c1..ea6640d10ee48bbc6405cfbc10f0cf718dbf1f2d 100644
--- a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
@@ -95,6 +95,7 @@ class MathExhaustiveTest : public ClientLibraryTestBase,
 // Checks a function's behavior on all fp16 values.
 //
 // TODO(jlebar): asin and lgamma tests fail on interpreter.
+#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
 XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) {
   const Testcase& tc = GetParam();
   XlaBuilder b(TestName());
@@ -137,6 +138,7 @@ XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) {
   tc.op(param);
   ComputeAndCompareR1<half>(&b, expected_result, {}, tc.error);
 }
+#endif

 // TODO(b/123355973): The following tests from math.cc are missing.
 //
@@ -163,10 +165,10 @@
 // TODO(b/123355973): Test bf16 and f32.
 // TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible.
 // TODO(b/123355973): Reduce lgamma error if possible; it is very high.
+// TODO(b/123355973): Move these into exhaustive_op_test.
 INSTANTIATE_TEST_CASE_P(
     MathExhaustiveTest_Instantiation, MathExhaustiveTest,
     ::testing::ValuesIn(std::vector<Testcase>{
-        Testcase{"sqrt", Sqrt, std::sqrt}.set_skip_neg_inf(),
         Testcase{"rsqrt", Rsqrt, [](float x) { return 1 / std::sqrt(x); }}
             .set_tolerance(0.05, 0.05)
             .set_skip_infs()
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index bdfb0575f573716b54cf9116d155d8a3a55056e8..f5ba3e78056a27f2b2f981df42bc40a3dedbce00 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -104,6 +105,37 @@ class MathTypedTest : public MathTest {
                      {true, false, false, false, false, false, false}),
         {}, error_spec_);
   }
+
+  // sqrt(x) == pow(x, 0.5) except that
+  //
+  //   pow(-inf, 0.5) == inf, while
+  //   sqrt(-inf) == nan.
+  //
+  // Check that none of our backends are incorrectly assuming that sqrt(x) ==
+  // pow(x, 0.5) without checking this edge case.
+  //
+  // For good measure, we also check pow with an exponent other than 0.5.
+  void TestSqrtPowInequivalence() {
+    SetFastMathDisabled(true);
+
+    // Tests disable constant folding by default, but this test needs it
+    // enabled, otherwise we don't tickle the bug we're trying to catch.
+    // Specifically, without constant folding, the constants we pass to Pow
+    // below are hidden behind a reshape that's never folded away!
+    mutable_debug_options()->clear_xla_disable_hlo_passes();
+
+    const T inf(std::numeric_limits<float>::infinity());
+    const T nan(std::numeric_limits<float>::quiet_NaN());
+
+    XlaBuilder b(TestName());
+    auto x = AddParam(LiteralUtil::CreateR1<T>({-inf}), &b);
+    ConstantR1<T>(&b, {-inf});
+    ConcatInDim(
+        &b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))},
+        0);
+    std::vector<T> expected = {nan, inf, inf};
+    ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
+  }
 };

 // TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
@@ -119,6 +151,9 @@ XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); }
 XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); }
 XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); }
 XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); }
+XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) {
+  this->TestSqrtPowInequivalence();
+}

 // Check that certain ops only support real, floating-point inputs.
 //
@@ -239,6 +274,7 @@ XLA_TEST_F(MathTest, Lgamma) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

+#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
 XLA_TEST_F(MathTest, LgammaF16) {
   SetFastMathDisabled(true);

@@ -259,6 +295,7 @@ XLA_TEST_F(MathTest, LgammaF16) {
   };
   ComputeAndCompareR1<half>(&b, expected, {}, ErrorSpec{0.1});
 }
+#endif

 XLA_TEST_F(MathTest, Digamma) {
   XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index dd74788a0e2940e88dfca1ffa4a4cdad7c1997e2..b36ed5c6e0809de8b3bd037587029f0e622d43dc 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include

+#include "llvm/IR/DerivedTypes.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
@@ -191,39 +192,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
   PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
   PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
   PrimitiveType output_type = op->shape().element_type();
-  llvm::Type* llvm_ty = lhs_value->getType();
-
-  auto make_sqrt = [&, this]() -> StatusOr<llvm::Value*> {
-    // NVPTX has four relevant square root instructions:
-    //   sqrt.approx{.ftz}.f32
-    //   sqrt.rn{.ftz}.f32
-    //   sqrt.rn.f64
-    //   rsqrt.approx.f64
-    // We rely on LLVM's NVPTX backend to pick the right one based on our
-    // fast-math options. (If fast-math is enabled, llvm may compute the 64-bit
-    // sqrt from the rsqrt approximation.)
-    return EmitLlvmIntrinsicMathCall("llvm.sqrt", {lhs_value}, {lhs_input_type},
-                                     output_type);
-  };
-
-  const HloInstruction* rhs = op->operand(1);
-  if (IsFPLiteralWithValue(rhs, .5)) {
-    VLOG(10) << "emitting pow(A, .5) as sqrt(A): " << op->ToString();
-    return make_sqrt();
-  }
-
-  if (IsFPLiteralWithValue(rhs, -.5)) {
-    VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
-    // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
-    // rsqrt.approx instruction.
-    //
-    // TODO(jlebar): Does this happen with fastmath disabled? If not, should
-    // we force-enable it?
-    TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
-    return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
-  }
-
-  VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString();
   return EmitLibdeviceMathCall("__nv_pow", {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
 }
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 562854756628df64fbf92d40af859f8b218b0cc2..7158708e9c3935a37aab4c0914a8ff569928b5be 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -674,8 +674,8 @@ xla_test(
 )

 xla_test(
-    name = "exhaustive_f32_elementwise_op_test",
-    srcs = ["exhaustive_f32_elementwise_op_test.cc"],
+    name = "exhaustive_op_test",
+    srcs = ["exhaustive_op_test.cc"],
     real_hardware_only = True,  # Very slow on the interpreter.
     shard_count = 48,
     tags = [
@@ -687,6 +687,7 @@ xla_test(
        ":client_library_test_base",
        ":literal_test_util",
        "//tensorflow/compiler/xla/client:xla_builder",
+       "//tensorflow/compiler/xla/client/lib:constants",
        "//tensorflow/compiler/xla/client/lib:math",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
deleted file mode 100644
index b961e6102692cb3b90976d621c62cb4cf18a9b6b..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ /dev/null
@@ -1,237 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include
-#include "absl/base/casts.h"
-#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/tests/test_macros.h"
-
-namespace xla {
-namespace {
-
-class ExhaustiveF32ElementwiseOpTest
-    : public ClientLibraryTestBase,
-      public ::testing::WithParamInterface<std::pair<int64, int64>> {
- protected:
-  ErrorSpec error_spec_{0.0001, 0.0001};
-
-  bool IsClose(float expected, float actual) {
-    float abs_err = std::abs(expected - actual);
-    float rel_err = abs_err / std::abs(expected);
-    return abs_err < error_spec_.abs || rel_err < error_spec_.rel ||
-           (std::isnan(expected) && std::isnan(actual)) ||
-           (std::isinf(expected) && std::isinf(actual) &&
-            (expected > 0) == (actual > 0));
-  }
-
-  template <typename EnqueueOpTy>
-  void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op,
-                             float (*evaluate_op)(float),
-                             std::pair<int64, int64> known_incorrect_range) {
-    SetFastMathDisabled(true);
-
-    int64 begin, end;
-    std::tie(begin, end) = GetParam();
-    int64 input_size = end - begin;
-
-    if (begin >= known_incorrect_range.first &&
-        end <= known_incorrect_range.second) {
-      LOG(INFO) << absl::StreamFormat(
-          "Skipping this shard, as the range under test, [%d, %d), falls "
-          "entirely within the known-incorrect range [%d, %d).",
-          begin, end, known_incorrect_range.first,
-          known_incorrect_range.second);
-      return;
-    }
-
-    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
-
-    XlaBuilder builder(TestName());
-
-    auto ith_input_elem = [&](int64 i) -> float {
-      i += begin;
-      // If the operation is known to be buggy on a specific input clamp that
-      // input to 0 under the assumption that the op is at least correct on 0.
-      if (i >= known_incorrect_range.first &&
-          i < known_incorrect_range.second) {
-        return 0;
-      }
-      return absl::bit_cast(i);
-    };
-
-    Literal input_literal =
-        LiteralUtil::CreateFromDimensions(F32, {input_size});
-    absl::Span<float> input_arr = input_literal.data<float>();
-    for (int64 i = 0; i < input_size; i++) {
-      input_arr[i] = ith_input_elem(i);
-    }
-    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
-    enqueue_op(&builder, input);
-    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
-
-    // Build and run the computation using the LocalClient API, rather than the
-    // plain Client API, which is used by ClientLibraryTestBase. This is
-    // because the plain Client API results does more memcpys to/from Literals,
-    // and that's slow given that we're touching a lot of data here.
-    //
-    // Copy debug options from ClientLibraryTestBase. In particular, we're
-    // interested in disabling constant folding.
-    ExecutableBuildOptions build_opts;
-    *build_opts.mutable_debug_options() = *mutable_debug_options();
-    TF_ASSERT_OK_AND_ASSIGN(
-        auto executable,
-        client_->Compile(comp, {&input_literal.shape()}, build_opts));
-
-    TF_ASSERT_OK_AND_ASSIGN(
-        ScopedShapedBuffer input_data,
-        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));
-
-    ExecutableRunOptions run_opts;
-    run_opts.set_allocator(client_->backend().memory_allocator());
-    run_opts.set_intra_op_thread_pool(
-        client_->backend().eigen_intra_op_thread_pool_device());
-    TF_ASSERT_OK_AND_ASSIGN(ScopedShapedBuffer result,
-                            executable->Run({&input_data}, run_opts));
-
-    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
-                            client_->ShapedBufferToLiteral(result));
-
-    // We essentially reimplement LiteralTestUtil::Near here because
-    //  a) this streamlined implementation is much faster, and
-    //  b) we can print out better error messages (namely, we can print out
-    //     which floating-point value input failed, while LiteralTestUtil::Near
-    //     can only print out the input index that failed).
-    //  c) we need special handling of certain inputs. For example, we say that
-    //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
-    //     and just needs to be close to one of them.
-    absl::Span<const float> result_arr = result_literal.data<float>();
-    ASSERT_EQ(result_arr.size(), input_arr.size());
-    int64 mismatches = 0;
-    // Hoisting this out of the loop is a nice speedup on shards that have many
-    // denormals.
-    const float expected_at_zero = evaluate_op(0);
-    for (int64 i = 0; i < input_arr.size(); ++i) {
-      float input = ith_input_elem(i);
-      float actual = result_arr[i];
-      float expected = evaluate_op(input);
-      if (IsClose(expected, actual)) {
-        continue;
-      }
-
-      constexpr int64 kMaxMismatchesPrinted = 1000;
-      if (std::fpclassify(input) == FP_SUBNORMAL) {
-        // For denormal inputs, we accept answers that are close to either
-        //  - evaluate_op(input) OR
-        //  - evaluate_op(0).
-        if (IsClose(expected_at_zero, actual)) {
-          continue;
-        }
-        ++mismatches;
-        if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) {
-          // Use %0.9g because that's guaranteed to print an f32 to full
-          // precision.
-          LOG(ERROR) << absl::StreamFormat(
-              "Mismatch on denormal value %0.9g (0x%08x). Expected either "
-              "%0.9g (0x%08x) (evaluated at true value) or %0.9g (0x%08x) "
-              "(evaluated at zero), but got %0.9g (0x%08x).",
-              input, absl::bit_cast(input),          //
-              expected, absl::bit_cast(expected),    //
-              expected_at_zero, absl::bit_cast(expected_at_zero),
-              actual, absl::bit_cast(actual));
-        }
-      } else {
-        mismatches++;
-        if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) {
-          LOG(ERROR) << absl::StreamFormat(
-              "Mismatch on %0.9g (0x%08x). Expected %0.9g (0x%08x), but got "
-              "%0.9g (0x%08x).",
-              input, absl::bit_cast(input),        //
-              expected, absl::bit_cast(expected),  //
-              actual, absl::bit_cast(actual));
-        }
-      }
-
-      if (mismatches == kMaxMismatchesPrinted && !VLOG_IS_ON(2)) {
-        LOG(ERROR) << "Not printing any more mismatches; pass "
-                      "--vmodule=exhaustive_f32_elementwise_op_test=2 to see "
-                      "all of them.";
-      }
-    }
-    EXPECT_EQ(mismatches, 0);
-  }
-};
-
-XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) {
-#if !defined(XLA_TEST_BACKEND_CPU) && !defined(XLA_TEST_BACKEND_GPU)
-  error_spec_ = ErrorSpec{0.001, 0.001};
-#endif
-  ExhaustivelyTestF32Op(
-      [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log,
-      /*known_incorrect_range=*/{0, 0});
-}
-
-XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) {
-#ifdef XLA_TEST_BACKEND_CPU
-  // TODO(b/73142289): The vectorized Exp implementation gives results outside
-  // our error spec in this range (these numbers are bitwise representations of
-  // floats expressed as a zero extended int64):
-  std::pair<int64, int64> known_incorrect_range = {1107296256 + 11583654,
-                                                   1107296256 + 11629080};
-#else
-  std::pair<int64, int64> known_incorrect_range = {0, 0};
-#endif
-
-  ExhaustivelyTestF32Op(
-      [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp,
-      known_incorrect_range);
-}
-
-XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) {
-  ExhaustivelyTestF32Op(
-      [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh,
-      /*known_incorrect_range=*/{0, 0});
-}
-
-XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfF32) {
-  ExhaustivelyTestF32Op(
-      [](XlaBuilder* builder, const XlaOp& input) { Erf(input); }, std::erf,
-      /*known_incorrect_range=*/{0, 0});
-}
-
-XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfcF32) {
-  ExhaustivelyTestF32Op(
-      [](XlaBuilder* builder, const XlaOp& input) { Erfc(input); }, std::erfc,
-      /*known_incorrect_range=*/{0, 0});
-}
-
-std::vector<std::pair<int64, int64>> CreateExhaustiveParameters() {
-  // We break up the 2^32-element space into small'ish chunks to keep peak
-  // memory usage low.
-  std::vector<std::pair<int64, int64>> result;
-  const int64 step = 1 << 25;
-  for (int64 i = 0; i < (1l << 32); i += step) {
-    result.push_back({i, i + step});
-  }
-  return result;
-}
-
-INSTANTIATE_TEST_CASE_P(ExhaustiveF32ElementwiseOpTestInstance,
-                        ExhaustiveF32ElementwiseOpTest,
-                        ::testing::ValuesIn(CreateExhaustiveParameters()));
-}  // namespace
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1d94937656e9b158c9304c27075b4039cc0e6fbb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc
@@ -0,0 +1,450 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include "absl/base/casts.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+using Eigen::half;
+
+// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
+// guaranteed that we're printing the full number.
+//
+// If we have a floating-point number with S significand bits, we need
+//
+//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103)
+//
+// decimal digits to be guaranteed that we're printing the full number. For
+// F32/F16/BF16 this works out to 9/5/4 digits. See
+// https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf
+string StringifyNum(float x) {
+  return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
+}
+
+string StringifyNum(half x) {
+  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
+                         absl::bit_cast<uint16>(x));
+}
+
+string StringifyNum(bfloat16 x) {
+  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
+                         absl::bit_cast<uint16>(x));
+}
+
+// Test parameter is a tuple containing
+//   - primitive type under test,
+//   - (begin, end) range under test, as zero-extended int64s bitcast to the
+//     primitive type under test.
+class ExhaustiveOpTest
+    : public ClientLibraryTestBase,
+      public ::testing::WithParamInterface<
+          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
+ public:
+  ExhaustiveOpTest()
+      : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {}
+
+  void Run(std::function<void(XlaOp)> enqueue_op,
+           float (*evaluate_op)(float)) {
+    SetFastMathDisabled(true);
+
+    // Run all HLO passes. In particular, constant folding is disabled by
+    // default for tests, but we need to run it in order to tickle some bugs.
+    mutable_debug_options()->clear_xla_disable_hlo_passes();
+
+    PrimitiveType ty;
+    std::tie(ty, std::ignore) = GetParam();
+
+    switch (ty) {
+      case F32:
+        SetDefaultErrSpec(0.0001, 0.0001);
+        RunImpl<float, uint32>(enqueue_op, evaluate_op);
+        break;
+      case F16:
+        SetDefaultErrSpec(0.001, 0.001);
+        RunImpl<half, uint16>(enqueue_op, evaluate_op);
+        break;
+      case BF16:
+        SetDefaultErrSpec(0.001, 0.01);
+        RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
+        break;
+      default:
+        LOG(FATAL) << "Unhandled type.";
+    }
+  }
+
+  void SetDefaultErrSpec(float abs_err, float rel_err) {
+    if (!abs_err_.has_value()) {
+      abs_err_ = abs_err;
+    }
+    if (!rel_err_.has_value()) {
+      rel_err_ = rel_err;
+    }
+  }
+
+  template <typename T, typename IntegralT>
+  void RunImpl(std::function<void(XlaOp)> enqueue_op,
+               float (*evaluate_op)(float)) {
+    static_assert(
+        sizeof(T) == sizeof(IntegralT),
+        "IntegralT must be an unsigned integer type of the same width as T.");
+
+    PrimitiveType ty;
+    std::pair<int64, int64> test_range;
+    std::tie(ty, test_range) = GetParam();
+    int64 begin, end;
+    std::tie(begin, end) = test_range;
+
+    if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
+      LOG(INFO) << absl::StreamFormat(
+          "Skipping this shard, as the range under test, [%d, %d), falls "
+          "entirely within the known-incorrect range [%d, %d).",
+          begin, end, known_incorrect_begin_, known_incorrect_end_);
+      return;
+    }
+
+    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
+
+    int64 input_size = end - begin;
+    Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size});
+    absl::Span<T> input_arr = input_literal.data<T>();
+    for (int64 i = 0; i < input_size; i++) {
+      IntegralT input_val = i + begin;
+      // If the operation is known to be buggy on a specific input clamp that
+      // input to 0 under the assumption that the op is at least correct on 0.
+      if (input_val >= known_incorrect_begin_ &&
+          input_val < known_incorrect_end_) {
+        input_arr[i] = T{0};
+      } else {
+        input_arr[i] = absl::bit_cast<T>(input_val);
+      }
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+                            BuildAndRunComputation(enqueue_op, input_literal));
+    ExpectNear<T>(input_literal, result_literal, evaluate_op);
+  }
+
+  StatusOr<Literal> BuildAndRunComputation(
+      const std::function<void(XlaOp)>& enqueue_op,
+      const Literal& input_literal) {
+    XlaBuilder builder(TestName());
+    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
+    enqueue_op(input);
+    TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build());
+
+    // Build and run the computation using the LocalClient API, rather than the
+    // plain Client API, which is used by ClientLibraryTestBase. This is
+    // because the plain Client API does more memcpys to/from Literals, and
+    // that's slow given that we're touching a lot of data here.
+    //
+    // Copy debug options from ClientLibraryTestBase. In particular, we're
+    // interested in disabling constant folding.
+    ExecutableBuildOptions build_opts;
+    *build_opts.mutable_debug_options() = *mutable_debug_options();
+    TF_ASSIGN_OR_RETURN(
+        auto executable,
+        client_->Compile(comp, {&input_literal.shape()}, build_opts));
+
+    TF_ASSIGN_OR_RETURN(
+        ScopedShapedBuffer input_data,
+        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));
+
+    ExecutableRunOptions run_opts;
+    run_opts.set_allocator(client_->backend().memory_allocator());
+    run_opts.set_intra_op_thread_pool(
+        client_->backend().eigen_intra_op_thread_pool_device());
+    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+                        executable->Run({&input_data}, run_opts));
+
+    TF_ASSIGN_OR_RETURN(Literal result_literal,
+                        client_->ShapedBufferToLiteral(result));
+    return std::move(result_literal);
+  }
+
+  template <typename T>
+  bool IsClose(T expected, T actual) {
+    float expected_f32 = static_cast<float>(expected);
+    float actual_f32 = static_cast<float>(actual);
+    float abs_err = std::abs(expected_f32 - actual_f32);
+    float rel_err = abs_err / std::abs(expected_f32);
+    if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) {
+      // Check sign of zero.
+      return std::signbit(actual_f32) == std::signbit(expected_f32);
+    }
+    return abs_err < *abs_err_ || rel_err < *rel_err_ ||
+           (std::isnan(expected_f32) && std::isnan(actual_f32)) ||
+           (std::isinf(expected_f32) && std::isinf(actual_f32) &&
+            (expected_f32 > 0) == (actual_f32 > 0));
+  }
+
+  template <typename T>
+  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
+                  float (*evaluate_op)(float)) {
+    // We essentially reimplement LiteralTestUtil::Near here because
+    //  a) this streamlined implementation is much faster, and
+    //  b) we can print out better error messages (namely, we can print out
+    //     which floating-point value input failed, while LiteralTestUtil::Near
+    //     can only print out the input index that failed).
+    //  c) we need special handling of certain inputs. For example, we say that
+    //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
+    //     and just needs to be close to one of them.
+    absl::Span<const T> input_arr = input_literal.data<T>();
+    absl::Span<const T> result_arr = result_literal.data<T>();
+    ASSERT_EQ(result_arr.size(), input_arr.size());
+    int64 mismatches = 0;
+    // Hoisting these out of the loop is a nice speedup on shards that have many
+    // denormals.
+    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
+    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
+    for (int64 i = 0; i < input_arr.size(); ++i) {
+      T input = input_arr[i];
+      float input_f32 = static_cast<float>(input);
+      T actual = result_arr[i];
+      T expected = static_cast<T>(evaluate_op(input_f32));
+
+      if (IsClose(expected, actual)) {
+        continue;
+      }
+
+      // Easy case: If `input` is not denormal and !IsClose(expected, actual),
+      // print an error.
+      //
+      // TODO(jlebar): This doesn't correctly detect f16 and bfloat16 denormals!
+      // This seems to be OK for now, but at some point we may need to implement
+      // fpclassify for half and bfloat.
+      if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
+        PrintMismatch(&mismatches, [&] {
+          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
+                                 StringifyNum(input), StringifyNum(expected),
+                                 StringifyNum(actual));
+        });
+        continue;
+      }
+
+      // Otherwise, `input` is denormal. For denormal inputs, we accept answers
+      // that are close to any of:
+      //
+      //   - evaluate_op(input)
+      //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
+      //     `input`,
+      //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
+      //     0 is the opposite of `input`.
+      T sign_preserving_ftz_expected =
+          std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
+      T sign_nonpreserving_ftz_expected =
+          std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
+      if (IsClose(sign_preserving_ftz_expected, actual) ||
+          (relaxed_denormal_signs_ &&
+           IsClose(sign_nonpreserving_ftz_expected, actual))) {
+        continue;
+      }
+
+      if (relaxed_denormal_signs_) {
+        PrintMismatch(&mismatches, [&] {
+          return absl::StrFormat(
+              "Mismatch on denormal value %s. Expected one of:\n"
+              "  %10s (evaluated at full-precision value)\n"
+              "  %10s (evaluated after flushing to sign-preserving zero)\n"
+              "  %10s (evaluated after flushing to non-sign-preserving "
+              "zero)\n"
+              "but got %s.",
+              StringifyNum(input), StringifyNum(expected),
+              StringifyNum(sign_preserving_ftz_expected),
+              StringifyNum(sign_nonpreserving_ftz_expected),
+              StringifyNum(actual));
+        });
+      } else {
+        PrintMismatch(&mismatches, [&] {
+          return absl::StrFormat(
+              "Mismatch on denormal value %s. Expected one of:\n"
+              "  %10s (evaluated at full-precision value)\n"
+              "  %10s (evaluated after flushing to sign-preserving zero)\n"
+              "but got %s.",
+              StringifyNum(input), StringifyNum(expected),
+              StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
+        });
+      }
+    }
+    EXPECT_EQ(mismatches, 0);
+  }
+
+  template <typename ErrorGenerator>
+  void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
+    // We send a few mismatches to gunit so they show up nicely in test logs.
+    // Then we send more to LOG(ERROR). The remainder we squelch unless we're
+    // at vlog level 2.
+    constexpr int64 kMaxMismatchesLoggedToGunit = 10;
+    constexpr int64 kMaxMismatchesLoggedToErr = 1000;
+
+    (*mismatches)++;
+    if (*mismatches < kMaxMismatchesLoggedToGunit) {
+      FAIL() << err_generator();
+    } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
+      LOG(ERROR) << err_generator();
+    } else if (*mismatches == kMaxMismatchesLoggedToErr) {
+      LOG(ERROR) << "Not printing any more mismatches; pass "
+                    "--vmodule=exhaustive_op_test=2 to see "
+                    "all of them.";
+    }
+  }
+
+  // The following members are set during construction so testcases can read
+  // these values and use them e.g. to influence the values given to the mutable
+  // members below.
+
+  // The primitive type under test.
+  const PrimitiveType ty_;
+
+  // The platform under test.
+  const string platform_;
+
+  // Tests can set the following variables for control over execution. This is
+  // safe because each XLA_TEST_P instantiates a new instance of this class.
+
+  // Testing will ignore the given range (encoded as bitwise representations of
+  // the type under test zero-extended to int64).
+  int64 known_incorrect_begin_ = 0;
+  int64 known_incorrect_end_ = 0;
+
+  // If unset, reasonable defaults will be used depending on the type under
+  // test.
+  absl::optional<float> abs_err_;
+  absl::optional<float> rel_err_;
+
+  // If true, will consider -0 not near to +0 and vice versa. Note that
+  // +epsilon may still be considered close to -0, depending on the error spec;
+  // this only covers the case when both `expected` and `actual` are equal to 0.
+  bool strict_signed_zeros_ = false;
+
+  // If true, allows denormals to be flushed to non-sign-preserving 0.
+  //
+  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of
+  // a negative number) or -inf (flush the denormal to sign-preserving zero,
+  // then sqrt(-0)). But with this as true, we'll also accept 0 (sqrt(0)).
+  //
+  // XLA:GPU preserves denormal signs, but other backends don't.
+  bool relaxed_denormal_signs_ = platform_ != "CUDA";
+};
+
+XLA_TEST_P(ExhaustiveOpTest, Log) {
+  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
+    abs_err_ = 0.001;
+    rel_err_ = 0.001;
+  }
+
+  Run(Log, std::log);
+}
+
+XLA_TEST_P(ExhaustiveOpTest, Exp) {
+  if (platform_ == "Host" && ty_ == F32) {
+    // TODO(b/73142289): The vectorized Exp implementation gives results outside
+    // our error spec in this range.
+    known_incorrect_begin_ = 1107296256 + 11583654;
+    known_incorrect_end_ = 1107296256 + 11629080;
+  } else if (platform_ == "Host" && ty_ == BF16) {
+    // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU?
+    //
+    // Mismatch on 88.5 (0x42b1).
+    // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
+    known_incorrect_begin_ = 0x42b1;
+    known_incorrect_end_ = 0x42b2;
+  }
+
+  Run(Exp, std::exp);
+}
+
+// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
+// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
+// pow(x, 0.5), but this is not true for x == -inf.
+XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) {
+  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
+      +[](float x) { return std::pow(x, 0.5f); });
+}
+
+XLA_TEST_P(ExhaustiveOpTest, Rsqrt) {
+  Run(
+      Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
+}
+
+XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
+  if (platform_ == "Host" || platform_ == "CUDA") {
+    strict_signed_zeros_ = true;
+  }
+
+  Run(Sqrt, std::sqrt);
+}
+
+XLA_TEST_P(ExhaustiveOpTest, Tanh) {
+  // TODO(jlebar): Enable this test for (b)f16.
+  if (ty_ == F16 || ty_ == BF16) {
+    return;
+  }
+  Run(Tanh, std::tanh);
+}
+XLA_TEST_P(ExhaustiveOpTest, Erf) {
+  // TODO(jlebar): Enable this test for (b)f16.
+  if (ty_ == F16 || ty_ == BF16) {
+    return;
+  }
+  Run(Erf, std::erf);
+}
+XLA_TEST_P(ExhaustiveOpTest, Erfc) {
+  // TODO(jlebar): Enable this test for (b)f16.
+  if (ty_ == F16 || ty_ == BF16) {
+    return;
+  }
+  Run(Erfc, std::erfc);
+}
+
+std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
+  // We break up the 2^32-element space into small'ish chunks to keep peak
+  // memory usage low.
+  std::vector<std::pair<int64, int64>> result;
+  const int64 step = 1 << 25;
+  for (int64 i = 0; i < (1l << 32); i += step) {
+    result.push_back({i, i + step});
+  }
+  return result;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    F32, ExhaustiveOpTest,
+    ::testing::Combine(::testing::Values(F32),
+                       ::testing::ValuesIn(CreateExhaustiveF32Ranges())));
+
+#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
+INSTANTIATE_TEST_SUITE_P(
+    F16, ExhaustiveOpTest,
+    ::testing::Combine(::testing::Values(F16),
+                       ::testing::Values(std::make_pair(0, 1 << 16))));
+#endif
+
+#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
+INSTANTIATE_TEST_SUITE_P(
+    BF16, ExhaustiveOpTest,
+    ::testing::Combine(::testing::Values(BF16),
+                       ::testing::Values(std::make_pair(0, 1 << 16))));
+#endif
+
+}  // namespace
+}  // namespace xla
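
Reviewer note (not part of the patch): the edge case that motivates the new SqrtPowInequivalence and PowOneHalf tests can be reproduced outside of XLA with a minimal host-only sketch that relies only on standard IEEE-754 / C99 libm semantics; the file name and program below are illustrative, not from this change.

// sqrt_pow_inequivalence_demo.cc -- standalone illustration of why a backend
// must not lower sqrt(x) as pow(x, 0.5): the two disagree at x == -inf.
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double neg_inf = -std::numeric_limits<double>::infinity();
  // C99 / IEEE-754: pow(-inf, y) == +inf for y > 0 and y not an odd integer.
  std::printf("pow(-inf, 0.5) = %g\n", std::pow(neg_inf, 0.5));  // prints inf
  // sqrt of any negative argument, including -inf, is NaN.
  std::printf("sqrt(-inf)     = %g\n", std::sqrt(neg_inf));      // prints nan
  return 0;
}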