From fd53853d50028ca744d316537754c92a072cf9f5 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 22 Feb 2019 16:35:39 -0800 Subject: [PATCH] [XLA] Remove math_exhaustive_test. Now mostly covered by exhaustive_op_test. The big missing gap is the trig functions. But to fix these, we need to carefully define how they're supposed to work over complex inputs, and that's a bigger task than I'm ready to take on at the moment. PiperOrigin-RevId: 235281417 --- tensorflow/compiler/xla/client/lib/BUILD | 16 -- .../xla/client/lib/math_exhaustive_test.cc | 175 ------------------ .../compiler/xla/tests/exhaustive_op_test.cc | 18 +- 3 files changed, 10 insertions(+), 199 deletions(-) delete mode 100644 tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index b8c0be48336..f264ec50a26 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -208,22 +208,6 @@ xla_test( ], ) -xla_test( - name = "math_exhaustive_test", - srcs = ["math_exhaustive_test.cc"], - shard_count = 16, - deps = [ - ":math", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "matrix", srcs = ["matrix.cc"], diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc deleted file mode 100644 index 7f5add15bae..00000000000 --- a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { -namespace { - -using Eigen::half; - -struct Testcase { - Testcase(string name, const std::function& op, - float (*host_op)(float)) - : name(name), op(op), host_op(host_op) {} - - Testcase& set_tolerance(float abs_err, float rel_err) { - error.abs = abs_err; - error.rel = rel_err; - return *this; - } - - Testcase& set_relaxed_nans() { - error.relaxed_nans = true; - return *this; - } - - Testcase& set_fewer_infs_ok() { - error.fewer_infs_ok = true; - return *this; - } - - Testcase& set_skip_pos_inf() { - skip_pos_inf = true; - return *this; - } - - Testcase& set_skip_neg_inf() { - skip_neg_inf = true; - return *this; - } - - Testcase& set_skip_infs() { - skip_pos_inf = true; - skip_neg_inf = true; - return *this; - } - - Testcase& set_skip_neg_zero() { - skip_neg_zero = true; - return *this; - } - - string name; - std::function op; - float (*host_op)(float); - - ErrorSpec error{0.01, 0.01}; - - // If true, don't test +/-infinity or negative 0. - bool skip_pos_inf = false; - bool skip_neg_inf = false; - bool skip_neg_zero = false; -}; - -void PrintTo(const Testcase& tc, std::ostream* os) { *os << tc.name; } - -class MathExhaustiveTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface { - public: - MathExhaustiveTest() { - // Disable fast-math, otherwise we get the wrong results for e.g. - // sqrt(-inf). - SetFastMathDisabled(true); - } -}; - -// Checks a function's behavior on all fp16 values. -// -// TODO(jlebar): asin and lgamma tests fail on interpreter. -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) { - const Testcase& tc = GetParam(); - XlaBuilder b(TestName()); - - std::vector input; - for (uint32 i = 0; i < 1 << 16; ++i) { - half h; - h.x = i; - - // If we're not using infinity as an input, use 0 as a placeholder rather - // than simply skipping this element. We do this because when the test - // framework reports an incorrect answer, it tells us which index failed. - // So long as our inputs are a simple list of all possible float16s, we can - // convert an index to a half with e.g. the following Python: - // - // np.frombuffer(array('H', [12345]), dtype=np.float16)[0] - // - // but as soon as our list of inputs has any gaps, this doesn't work. - if (std::isinf(static_cast(h)) && - ((tc.skip_pos_inf && h > half{0}) || - (tc.skip_neg_inf && h < half{0}))) { - h = half{0}; - } - - if (h == half{0} && tc.skip_neg_zero && - std::signbit(static_cast(h))) { - h = half{0}; - } - - input.push_back(h); - } - - std::vector expected_result; - for (const auto& h : input) { - expected_result.push_back( - static_cast(tc.host_op(static_cast(h)))); - } - - XlaOp param = AddParam(LiteralUtil::CreateR1(input), &b); - tc.op(param); - ComputeAndCompareR1(&b, expected_result, {}, tc.error); -} -#endif - -// TODO(b/123355973): The following tests from math.cc are missing. -// -// - Many failures. -// -// Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(), -// Testcase{"asinh", Asinh, std::asinh}, -// Testcase{"sinh", Sinh, std::sinh}, -// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(), -// Testcase{"round_to_even", RoundToEven, -// [](float x) { return std::nearbyint(x / 2) * 2; }}, -// -// - Needs a special test (function takes two args, and simply computing in f32 -// and downcasting to f16 doesn't give the correct answer). -// -// Testcase{"nextafter", NextAfter, std::nextafter}, -// -// TODO(b/123355973): Test math functions not from math.cc (e.g. log). -// TODO(b/123355973): Test bf16 and f32. -// TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible. -// TODO(b/123355973): Move these into exhaustive_op_test. -INSTANTIATE_TEST_CASE_P( - MathExhaustiveTest_Instantiation, MathExhaustiveTest, - ::testing::ValuesIn(std::vector{ - Testcase{"square", Square, [](float x) { return x * x; }}, - Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }}, - Testcase{"asin", Asin, std::asin}.set_skip_infs(), - Testcase{"acos", Acos, std::acos}.set_skip_infs(), - Testcase{"atan", Atan, std::atan}, - Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05), - })); - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc index 9ec53488a2c..b409eee87d8 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -162,13 +162,12 @@ float HostDigamma(float x) { // For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be // guaranteed that we're printing the full number. // -// If we have a floating-point number with S significand bits, we need +// (The general formula is, given a floating-point number with S significand +// bits, the number of decimal digits needed to print it to full precision is // -// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103) +// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103). // -// decimal digits to be guaranteed that we're printing the full number. For -// F32/F16/BF16 this works out to 9/5/4 digits. See -// https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf +// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.) string StringifyNum(float x) { return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast(x)); } @@ -361,9 +360,9 @@ class ExhaustiveOpTest // Easy case: If `input` is not denormal and !IsClose(expected, actual), // print an error. // - // TODO(jlebar): This doesn't correctly detect f16 and bfloat16 denormals! - // This seems to be OK for now, but at some point we may need to implement - // fpclassify for half and bfloat. + // (This doesn't correctly detect f16 and bfloat16 denormals! This seems + // to be OK for now, but at some point we may need to implement fpclassify + // for half and bfloat.) if (std::fpclassify(input_f32) != FP_SUBNORMAL) { PrintMismatch(&mismatches, [&] { return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", @@ -526,7 +525,10 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) { Run(Sqrt, std::sqrt); } +// TODO(jlebar): Add remaining trig functions. Don't forget Atan2! +// TODO(jlebar): Test trig functions over complex inputs. XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } + XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); } XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); } -- GitLab