From a7397e0cb293e237fdb1b05e43ca7523efe2cc15 Mon Sep 17 00:00:00 2001
From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com>
Date: Thu, 23 Mar 2023 10:35:35 +0800
Subject: [PATCH] [AMP] Add bfloat16 and float16 tests for compare ops (#51978)

* add bf16 and fp16 tests

* fix dtype check
---
 paddle/phi/kernels/cpu/compare_kernel.cc     |  6 ++-
 .../fluid/tests/unittests/test_compare_op.py | 42 +++++++++++++++++++
 python/paddle/tensor/logic.py                |  4 +-
 3 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc
index 7687a872cce..cf8eb47fb42 100644
--- a/paddle/phi/kernels/cpu/compare_kernel.cc
+++ b/paddle/phi/kernels/cpu/compare_kernel.cc
@@ -94,7 +94,8 @@ PD_REGISTER_KERNEL(equal_all,
                      int64_t,                                 \
                      float,                                   \
                      double,                                  \
-                     phi::dtype::float16) {                   \
+                     phi::dtype::float16,                     \
+                     phi::dtype::bfloat16) {                  \
     kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
   }                                                           \
   PD_REGISTER_KERNEL(name##_raw,                              \
@@ -107,7 +108,8 @@ PD_REGISTER_KERNEL(equal_all,
                      int64_t,                                 \
                      float,                                   \
                      double,                                  \
-                     phi::dtype::float16) {                   \
+                     phi::dtype::float16,                     \
+                     phi::dtype::bfloat16) {                  \
     kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
   }
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py
index 0ea2e00a716..905c7f2a920 100755
--- a/python/paddle/fluid/tests/unittests/test_compare_op.py
+++ b/python/paddle/fluid/tests/unittests/test_compare_op.py
@@ -136,6 +136,15 @@ def create_paddle_case(op_type, callback):
             self.assertEqual((out.numpy() == self.real_result).all(), True)
             paddle.enable_static()
 
+        def test_dynamic_api_float16(self):
+            paddle.disable_static()
+            x = paddle.to_tensor(self.input_x, dtype="float16")
+            y = paddle.to_tensor(self.input_y, dtype="float16")
+            op = eval("paddle.%s" % (self.op_type))
+            out = op(x, y)
+            self.assertEqual((out.numpy() == self.real_result).all(), True)
+            paddle.enable_static()
+
         def test_dynamic_api_inf_1(self):
             if self.op_type == "equal":
                 paddle.disable_static()
@@ -434,6 +443,39 @@ create_paddle_case('equal', lambda _a, _b: _a == _b)
 create_paddle_case('not_equal', lambda _a, _b: _a != _b)
 
 
+# add bf16 tests
+def create_bf16_case(op_type, callback):
+    class TestCompareOpBF16Op(op_test.OpTest):
+        def setUp(self):
+            self.op_type = op_type
+            self.dtype = np.uint16
+            self.python_api = eval("paddle." + op_type)
+
+            x = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
+            y = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
+            real_result = callback(x, y)
+            self.inputs = {
+                'X': op_test.convert_float_to_uint16(x),
+                'Y': op_test.convert_float_to_uint16(y),
+            }
+            self.outputs = {'Out': real_result}
+
+        def test_check_output(self):
+            self.check_output()
+
+    cls_name = "BF16TestCase_{}".format(op_type)
+    TestCompareOpBF16Op.__name__ = cls_name
+    globals()[cls_name] = TestCompareOpBF16Op
+
+
+create_bf16_case('less_than', lambda _a, _b: _a < _b)
+create_bf16_case('less_equal', lambda _a, _b: _a <= _b)
+create_bf16_case('greater_than', lambda _a, _b: _a > _b)
+create_bf16_case('greater_equal', lambda _a, _b: _a >= _b)
+create_bf16_case('equal', lambda _a, _b: _a == _b)
+create_bf16_case('not_equal', lambda _a, _b: _a != _b)
+
+
 class TestCompareOpError(unittest.TestCase):
     def test_errors(self):
         paddle.enable_static()
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index e10e7c647be..f214bf0c861 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -746,13 +746,13 @@ def not_equal(x, y, name=None):
     check_variable_and_dtype(
         x,
         "x",
-        ["bool", "float32", "float64", "int32", "int64"],
+        ["bool", "float16", "float32", "float64", "int32", "int64"],
         "not_equal",
     )
     check_variable_and_dtype(
         y,
         "y",
-        ["bool", "float32", "float64", "int32", "int64"],
+        ["bool", "float16", "float32", "float64", "int32", "int64"],
         "not_equal",
     )
     helper = LayerHelper("not_equal", **locals())
-- 
GitLab
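
A minimal usage sketch (not part of the patch) mirroring the new
test_dynamic_api_float16 case: both operands are cast to float16 and compared
through the dynamic-graph API. The tensor values here are illustrative only,
not taken from the tests above.

    import paddle

    # Dynamic-graph float16 comparison, as exercised by the new test case.
    x = paddle.to_tensor([1.0, 2.0, 3.0], dtype="float16")
    y = paddle.to_tensor([1.0, 1.0, 4.0], dtype="float16")
    print(paddle.less_than(x, y).numpy())  # expected: [False False  True]

The bfloat16 path is covered differently in the patch: the OpTest cases feed
uint16-packed inputs via op_test.convert_float_to_uint16 rather than calling
the dynamic API directly.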