diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 6e2cd730fa1390c92b694367bc0931af7aed0dc4..d1e9e4097975c6a7007b19a44225efe3ea057f99 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -15,11 +15,11 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid -from paddle.fluid import Program, program_guard +from paddle.fluid import Program, core, program_guard from paddle.fluid.backward import append_backward @@ -50,6 +50,50 @@ class TestWhereOp2(TestWhereOp): self.cond = np.ones((60, 2)).astype('bool') +class TestWhereFP16OP(TestWhereOp): + def init_config(self): + self.dtype = np.float16 + self.x = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype) + self.y = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype) + self.cond = np.ones((60, 2)).astype('bool') + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestWhereBF16OP(OpTest): + def setUp(self): + self.op_type = 'where' + self.dtype = np.uint16 + self.python_api = paddle.where + self.init_config() + self.inputs = { + 'Condition': self.cond, + 'X': convert_float_to_uint16(self.x), + 'Y': convert_float_to_uint16(self.y), + } + self.outputs = { + 'Out': convert_float_to_uint16(np.where(self.cond, self.x, self.y)) + } + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X', 'Y'], 'Out', check_eager=False, numeric_grad_delta=0.05 + ) + + def init_config(self): + self.x = np.random.uniform((-5), 5, (60, 2)).astype(np.float32) + self.y = np.random.uniform((-5), 5, (60, 2)).astype(np.float32) + self.cond = np.random.randint(2, size=(60, 2)).astype('bool') + + class TestWhereOp3(TestWhereOp): def init_config(self): self.x = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')