提交 fd53853d 编写于 作者: J Justin Lebar 提交者: TensorFlower Gardener

[XLA] Remove math_exhaustive_test.

Now mostly covered by exhaustive_op_test.  The big missing gap is the trig
functions.  But to fix these, we need to carefully define how they're supposed
to work over complex inputs, and that's a bigger task than I'm ready to take on
at the moment.

PiperOrigin-RevId: 235281417
上级 150f75e4
...@@ -208,22 +208,6 @@ xla_test( ...@@ -208,22 +208,6 @@ xla_test(
], ],
) )
xla_test(
name = "math_exhaustive_test",
srcs = ["math_exhaustive_test.cc"],
shard_count = 16,
deps = [
":math",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library( cc_library(
name = "matrix", name = "matrix",
srcs = ["matrix.cc"], srcs = ["matrix.cc"],
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace {
using Eigen::half;
struct Testcase {
Testcase(string name, const std::function<XlaOp(XlaOp)>& op,
float (*host_op)(float))
: name(name), op(op), host_op(host_op) {}
Testcase& set_tolerance(float abs_err, float rel_err) {
error.abs = abs_err;
error.rel = rel_err;
return *this;
}
Testcase& set_relaxed_nans() {
error.relaxed_nans = true;
return *this;
}
Testcase& set_fewer_infs_ok() {
error.fewer_infs_ok = true;
return *this;
}
Testcase& set_skip_pos_inf() {
skip_pos_inf = true;
return *this;
}
Testcase& set_skip_neg_inf() {
skip_neg_inf = true;
return *this;
}
Testcase& set_skip_infs() {
skip_pos_inf = true;
skip_neg_inf = true;
return *this;
}
Testcase& set_skip_neg_zero() {
skip_neg_zero = true;
return *this;
}
string name;
std::function<XlaOp(XlaOp)> op;
float (*host_op)(float);
ErrorSpec error{0.01, 0.01};
// If true, don't test +/-infinity or negative 0.
bool skip_pos_inf = false;
bool skip_neg_inf = false;
bool skip_neg_zero = false;
};
void PrintTo(const Testcase& tc, std::ostream* os) { *os << tc.name; }
class MathExhaustiveTest : public ClientLibraryTestBase,
public ::testing::WithParamInterface<Testcase> {
public:
MathExhaustiveTest() {
// Disable fast-math, otherwise we get the wrong results for e.g.
// sqrt(-inf).
SetFastMathDisabled(true);
}
};
// Checks a function's behavior on all fp16 values.
//
// TODO(jlebar): asin and lgamma tests fail on interpreter.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) {
const Testcase& tc = GetParam();
XlaBuilder b(TestName());
std::vector<half> input;
for (uint32 i = 0; i < 1 << 16; ++i) {
half h;
h.x = i;
// If we're not using infinity as an input, use 0 as a placeholder rather
// than simply skipping this element. We do this because when the test
// framework reports an incorrect answer, it tells us which index failed.
// So long as our inputs are a simple list of all possible float16s, we can
// convert an index to a half with e.g. the following Python:
//
// np.frombuffer(array('H', [12345]), dtype=np.float16)[0]
//
// but as soon as our list of inputs has any gaps, this doesn't work.
if (std::isinf(static_cast<float>(h)) &&
((tc.skip_pos_inf && h > half{0}) ||
(tc.skip_neg_inf && h < half{0}))) {
h = half{0};
}
if (h == half{0} && tc.skip_neg_zero &&
std::signbit(static_cast<float>(h))) {
h = half{0};
}
input.push_back(h);
}
std::vector<half> expected_result;
for (const auto& h : input) {
expected_result.push_back(
static_cast<half>(tc.host_op(static_cast<float>(h))));
}
XlaOp param = AddParam(LiteralUtil::CreateR1<half>(input), &b);
tc.op(param);
ComputeAndCompareR1<half>(&b, expected_result, {}, tc.error);
}
#endif
// TODO(b/123355973): The following tests from math.cc are missing.
//
// - Many failures.
//
// Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(),
// Testcase{"asinh", Asinh, std::asinh},
// Testcase{"sinh", Sinh, std::sinh},
// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(),
// Testcase{"round_to_even", RoundToEven,
// [](float x) { return std::nearbyint(x / 2) * 2; }},
//
// - Needs a special test (function takes two args, and simply computing in f32
// and downcasting to f16 doesn't give the correct answer).
//
// Testcase{"nextafter", NextAfter, std::nextafter},
//
// TODO(b/123355973): Test math functions not from math.cc (e.g. log).
// TODO(b/123355973): Test bf16 and f32.
// TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible.
// TODO(b/123355973): Move these into exhaustive_op_test.
INSTANTIATE_TEST_CASE_P(
MathExhaustiveTest_Instantiation, MathExhaustiveTest,
::testing::ValuesIn(std::vector<Testcase>{
Testcase{"square", Square, [](float x) { return x * x; }},
Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }},
Testcase{"asin", Asin, std::asin}.set_skip_infs(),
Testcase{"acos", Acos, std::acos}.set_skip_infs(),
Testcase{"atan", Atan, std::atan},
Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05),
}));
} // namespace
} // namespace xla
...@@ -162,13 +162,12 @@ float HostDigamma(float x) { ...@@ -162,13 +162,12 @@ float HostDigamma(float x) {
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be // For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
// guaranteed that we're printing the full number. // guaranteed that we're printing the full number.
// //
// If we have a floating-point number with S significand bits, we need // (The general formula is, given a floating-point number with S significand
// bits, the number of decimal digits needed to print it to full precision is
// //
// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103) // ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
// //
// decimal digits to be guaranteed that we're printing the full number. For // See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
// F32/F16/BF16 this works out to 9/5/4 digits. See
// https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf
string StringifyNum(float x) { string StringifyNum(float x) {
return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x)); return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
} }
...@@ -361,9 +360,9 @@ class ExhaustiveOpTest ...@@ -361,9 +360,9 @@ class ExhaustiveOpTest
// Easy case: If `input` is not denormal and !IsClose(expected, actual), // Easy case: If `input` is not denormal and !IsClose(expected, actual),
// print an error. // print an error.
// //
// TODO(jlebar): This doesn't correctly detect f16 and bfloat16 denormals! // (This doesn't correctly detect f16 and bfloat16 denormals! This seems
// This seems to be OK for now, but at some point we may need to implement // to be OK for now, but at some point we may need to implement fpclassify
// fpclassify for half and bfloat. // for half and bfloat.)
if (std::fpclassify(input_f32) != FP_SUBNORMAL) { if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
PrintMismatch(&mismatches, [&] { PrintMismatch(&mismatches, [&] {
return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
...@@ -526,7 +525,10 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) { ...@@ -526,7 +525,10 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
Run(Sqrt, std::sqrt); Run(Sqrt, std::sqrt);
} }
// TODO(jlebar): Add remaining trig functions. Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); } XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); } XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册