From 6bd5fd752662d276e4e53e6d30eae1941377fde7 Mon Sep 17 00:00:00 2001
From: Vvsmile <17864154871@163.com>
Date: Mon, 10 Apr 2023 12:55:49 +0800
Subject: [PATCH] [AMP OP&Test] Add fp16 and bf16 test to activation (#52521)

* adjust default tolerance of output and grad
* fix a bug in the grad of OpTest
* fix the type of setting default value in optest, both forward and backward
* add default
* fix test_sum_op
* adjust tolerance
* fix the tolerance of eager
* add bf16 and fp16 to the activation tests
* remove some fixes
* fix activation
* fix fp16
* fix gelu
* fix the activation tests
* add bfloat16 specialization to singrad and cosgrad
* fix bugs
* fix bugs
* add unittest
* add skip
* add fp/bf to rrelu/rrelu_grad
* git add rrelu
* fix bugs
---
 paddle/phi/kernels/funcs/activation_functor.h |  22 +-
 .../phi/kernels/gpu/activation_grad_kernel.cu |  15 +-
 paddle/phi/kernels/gpu/rrelu_grad_kernel.cu   |   1 +
 paddle/phi/kernels/gpu/rrelu_kernel.cu        |   1 +
 .../fluid/tests/unittests/eager_op_test.py    |  28 +-
 .../tests/unittests/test_activation_op.py     | 290 +++++++++++++-----
 .../fluid/tests/unittests/test_rrelu_op.py    |  73 ++++-
 7 files changed, 339 insertions(+), 91 deletions(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 7c48c0a02b4..78a1f8cb24f 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -70,6 +70,13 @@ struct Sine {
   }
 };
 
+template <>
+struct Sine<dtype::bfloat16> {
+  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
+    return dtype::bfloat16(sin(static_cast<float>(val)));
+  }
+};
+
 template <typename T>
 struct Cosine {
   HOSTDEVICE T operator()(const T& val) const { return cos(val); }
@@ -82,6 +89,13 @@ struct Cosine {
   }
 };
 
+template <>
+struct Cosine<dtype::bfloat16> {
+  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
+    return dtype::bfloat16(cos(static_cast<float>(val)));
+  }
+};
+
 // sine'(x) = cos(x)
 template <typename T>
 struct SinGradFunctor : public BaseActivationFunctor<T> {
@@ -2664,10 +2678,12 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
 
 template <typename T>
 struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
-  T one = static_cast<T>(1.0f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
 
-  // reciprocal(x) = 1 / x
-  __device__ __forceinline__ T operator()(const T x) const { return one / x; }
+  __device__ __forceinline__ T operator()(const T x) const {
+    return static_cast<T>(one / static_cast<MPType>(x));
+  }
 };
 
 template <typename T>
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 5573d666776..e56c3cf4f42 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -425,7 +425,8 @@ PD_REGISTER_KERNEL(sin_double_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(sin_triple_grad,
                    GPU,
@@ -435,7 +436,8 @@ PD_REGISTER_KERNEL(sin_triple_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(cos_double_grad,
                    GPU,
@@ -445,7 +447,8 @@ PD_REGISTER_KERNEL(cos_double_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(cos_triple_grad,
                    GPU,
@@ -455,7 +458,8 @@ PD_REGISTER_KERNEL(cos_triple_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
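Note on the functor changes above: the fp16/bf16 paths deliberately avoid doing the math in the narrow type. The new Sine/Cosine specializations widen the value with static_cast<float> before calling sin/cos, and CudaReciprocalFunctor now picks its compute type through phi::dtype::MPTypeTrait<T>::Type and casts back to T only at the end. The host-only sketch below illustrates that widen-compute-narrow pattern in isolation; the BF16 struct and MPTypeOf trait are toy stand-ins written for this sketch, not Paddle types.

```cpp
// Minimal sketch of the widen-compute-narrow pattern used by the low-precision
// functors above. BF16 is a toy bfloat16 (truncated float32), NOT Paddle's
// phi::dtype::bfloat16, and MPTypeOf is a toy version of MPTypeTrait.
#include <cstdint>
#include <cstring>
#include <iostream>

struct BF16 {
  uint16_t bits = 0;
  BF16() = default;
  explicit BF16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<uint16_t>(u >> 16);  // keep sign, exponent, 7 mantissa bits
  }
  explicit operator float() const {
    uint32_t u = static_cast<uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

// Low-precision types compute in float; every other type computes in itself.
template <typename T> struct MPTypeOf { using Type = T; };
template <> struct MPTypeOf<BF16> { using Type = float; };

// Same shape as CudaReciprocalFunctor: widen the input, divide, narrow the result.
template <typename T>
struct ReciprocalFunctor {
  using MPType = typename MPTypeOf<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  T operator()(const T x) const {
    return static_cast<T>(one / static_cast<MPType>(x));
  }
};

int main() {
  ReciprocalFunctor<BF16> recip;
  std::cout << "1/3 through the toy bf16 path: "
            << static_cast<float>(recip(BF16(3.0f))) << "\n";  // prints ~0.332
}
```

Doing the division (or sin/cos) in float and rounding only once at the end keeps the result within roughly one low-precision rounding step of the true value, which is what lets the generated fp16/bf16 tests below share loose 1e-2-level tolerances.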
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) @@ -473,7 +477,8 @@ PD_REGISTER_KERNEL(log_double_grad, phi::LogDoubleGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel) diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu index fa9ef450307..361e4c28e16 100644 --- a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -83,4 +83,5 @@ PD_REGISTER_KERNEL(rrelu_grad, phi::RReluGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu index e872cbf3cb6..b15e525a3bc 100644 --- a/paddle/phi/kernels/gpu/rrelu_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -110,4 +110,5 @@ PD_REGISTER_KERNEL(rrelu, phi::RReluKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double) {} diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index b764a1acd0a..9b0868edfa7 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -885,7 +885,9 @@ class OpTest(unittest.TestCase): np_dyg, rtol=1e-05, equal_nan=False, - err_msg='Output (' + err_msg='Operator (' + + self.op_type + + ') Output (' + name + ') has diff at ' + str(place) @@ -1137,7 +1139,9 @@ class OpTest(unittest.TestCase): actual_out, rtol=1e-05, atol=inplace_atol, - err_msg='Output (' + err_msg='Operator (' + + self.op_type + + ') Output (' + name + ') has diff at ' + str(place) @@ -1626,7 +1630,9 @@ class OpTest(unittest.TestCase): rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, err_msg=( - "Output (" + "Operator (" + + self.op_type + + ") Output (" + name + ") has diff at " + str(place) @@ -1643,7 +1649,9 @@ class OpTest(unittest.TestCase): rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, ), - "Output (" + "Operator (" + + self.op_type + + ") Output (" + name + ") has diff at " + str(place) @@ -1815,7 +1823,9 @@ class OpTest(unittest.TestCase): rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, err_msg=( - "Output (" + "Operator (" + + self.op_type + + ") Output (" + name + ") has diff at " + str(place) @@ -1832,7 +1842,9 @@ class OpTest(unittest.TestCase): rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, ), - "Output (" + "Operator (" + + self.op_type + + ") Output (" + name + ") has diff at " + str(place) @@ -1882,7 +1894,9 @@ class OpTest(unittest.TestCase): .get_tensor() .recursive_sequence_lengths(), expect[1], - "Output (" + "Operator (" + + self.op_type + + ") Output (" + name + ") has different lod at " + str(place) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 8656c564136..dfa95f760ce 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -63,6 +63,8 @@ class TestActivation(OpTest): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + 
def test_check_output(self): self.check_output() @@ -83,6 +85,9 @@ class TestActivation(OpTest): def init_kernel_type(self): pass + def convert_input_output(self): + pass + class TestActivation_ZeroDim(TestActivation): def init_shape(self): @@ -148,6 +153,7 @@ class TestExpm1(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -247,6 +253,8 @@ class TestSigmoid(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def init_dtype(self): self.dtype = np.float32 @@ -320,10 +328,11 @@ class TestSilu(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = x / (np.exp(-x) + 1) - - self.inputs = {'X': x} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def init_dtype(self): self.dtype = np.float32 @@ -401,10 +410,11 @@ class TestLogSigmoid(TestActivation): np.random.seed(2048) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = np.log(1 / (1 + np.exp(-x))) - - self.inputs = {'X': x} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def test_check_grad(self): if self.dtype == np.float16: return @@ -479,9 +489,9 @@ class TestTanh(TestActivation, TestParameter): np.random.seed(1024) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) out = np.tanh(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -581,6 +591,7 @@ class TestAtan(TestActivation, TestParameter): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -625,10 +636,11 @@ class TestSinh(TestActivation): np.random.seed(1024) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) out = np.sinh(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def test_check_grad(self): if self.dtype == np.float16: return @@ -716,10 +728,11 @@ class TestCosh(TestActivation): np.random.seed(1024) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) out = np.cosh(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def test_check_grad(self): if self.dtype == np.float16: return @@ -812,10 +825,11 @@ class TestTanhshrink(TestActivation): np.random.seed(1024) x = np.random.uniform(10, 20, self.shape).astype(self.dtype) out = ref_tanhshrink(x) - - self.inputs = {'X': x} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + def test_check_grad(self): if self.dtype == np.float16: return @@ -895,10 +909,12 @@ class TestHardShrink(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) * 10 out = ref_hardshrink(x, self.threshold) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} self.attrs = {'threshold': self.threshold} - self.inputs = {'X': x} - self.outputs = {'Out': out} + + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1067,10 +1083,12 @@ class TestSoftshrink(TestActivation): 
np.random.seed(1023) x = np.random.uniform(0.25, 10, self.shape).astype(self.dtype) out = ref_softshrink(x, threshold) - self.inputs = {'X': x} - self.attrs = {"lambda": threshold} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.attrs = {"lambda": threshold} + def test_check_grad(self): if self.dtype == np.float16: return @@ -1154,6 +1172,7 @@ class TestSqrt(TestActivation, TestParameter): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() self.enable_cinn = False # TODO(wanghao107) add prim test @@ -1266,6 +1285,7 @@ class TestSqrtComp(TestActivation, TestParameter): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() self.enable_cinn = True def test_check_grad(self): @@ -1320,6 +1340,7 @@ class TestRsqrt(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() self.enable_cinn = True def init_shape(self): @@ -1368,6 +1389,7 @@ class TestAbs(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [4, 25] @@ -1396,6 +1418,7 @@ class TestCeil(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1425,6 +1448,7 @@ class TestFloor(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1496,9 +1520,9 @@ class TestCos(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = np.cos(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1534,6 +1558,7 @@ class TestTan(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x_np)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1603,6 +1628,7 @@ class TestAcos(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1632,9 +1658,9 @@ class TestSin(TestActivation, TestParameter): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = np.sin(x) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1663,6 +1689,7 @@ class TestAsin(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1691,6 +1718,7 @@ class TestAcosh(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1719,6 +1747,7 @@ class TestAsinh(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1747,6 +1776,7 @@ class TestAtanh(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + 
self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1775,6 +1805,7 @@ class TestRound(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -1799,20 +1830,14 @@ class TestRelu(TestActivation): self.skip_cinn() np.random.seed(1024) - if self.dtype == np.uint16: - x = np.random.uniform(-1, 1, self.shape).astype(np.float32) - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = convert_float_to_uint16(np.maximum(x, 0)) - self.inputs = {'X': convert_float_to_uint16(x)} - else: - x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = np.maximum(x, 0) - self.inputs = {'X': x} + x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + out = np.maximum(x, 0) + self.inputs = {'X': x} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -1921,6 +1946,7 @@ class TestLeakyRelu(TestActivation): self.inputs = {'X': x} self.outputs = {'Out': out} self.attrs = {'alpha': alpha} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2066,8 +2092,9 @@ class TestGelu(TestActivation): out = gelu(x, approximate) self.if_enable_cinn() - self.inputs = {'X': x} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() self.attrs = {"approximate": approximate} # The backward decomposite of gelu is inconsistent with raw kernel on # cpu, lower threshold to support 1e-8 for pass the unittest @@ -2175,8 +2202,9 @@ class TestBRelu(TestActivation): t[t > t_max] = t_max self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'t_min': t_min, 't_max': t_max} self.outputs = {'Out': t} + self.convert_input_output() + self.attrs = {'t_min': t_min, 't_max': t_max} def test_check_grad(self): if self.dtype == np.float16: @@ -2203,9 +2231,11 @@ class TestRelu6(TestActivation): x[np.abs(x) < 0.005] = 0.02 out = ref_relu6(x) - self.inputs = {'X': x} self.attrs = {'threshold': 6.0} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -2338,9 +2368,10 @@ class TestHardSwish(TestActivation): x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02 out = ref_hardswish(x, threshold, scale, offset) - self.inputs = {'X': x} - self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset} self.enable_cinn = False def init_shape(self): @@ -2450,8 +2481,9 @@ class TestSoftRelu(TestActivation): out = np.log(np.exp(t) + 1) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'threshold': threshold} self.outputs = {'Out': out} + self.convert_input_output() + self.attrs = {'threshold': threshold} def test_check_output(self): self.check_output(check_dygraph=False) @@ -2482,9 +2514,11 @@ class TestELU(TestActivation): out = elu(x, alpha) # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. 
alpha = 1) # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here - self.inputs = {'X': x} - self.attrs = {'alpha': alpha} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + self.attrs = {'alpha': alpha} def init_shape(self): self.shape = [10, 12] @@ -2597,9 +2631,11 @@ class TestCELU(TestActivation): x = np.random.uniform(-3, 3, self.shape).astype(self.dtype) alpha = 1.5 out = celu(x, alpha) - self.inputs = {'X': x} - self.attrs = {'alpha': alpha} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() + self.attrs = {'alpha': alpha} def init_shape(self): self.shape = [10, 12] @@ -2696,6 +2732,7 @@ class TestReciprocal(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2730,6 +2767,7 @@ class TestLog(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2782,6 +2820,7 @@ class TestLog2(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2844,6 +2883,7 @@ class TestLog10(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2909,6 +2949,7 @@ class TestLog1p(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -2981,6 +3022,7 @@ class TestSquare(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -3040,8 +3082,9 @@ class TestPow(TestActivation): out = np.power(x, 3) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'factor': 3.0} self.outputs = {'Out': out} + self.attrs = {'factor': 3.0} + self.convert_input_output() def test_check_output(self): self.check_output(check_prim=True) @@ -3142,9 +3185,10 @@ class TestSTanh(TestActivation): # The same reason with TestAbs out = ref_stanh(x, scale_a, scale_b) - self.inputs = {'X': x} - self.attrs = {'scale_a': scale_a, 'scale_b': scale_b} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.attrs = {'scale_a': scale_a, 'scale_b': scale_b} + self.convert_input_output() def test_check_grad(self): if self.dtype == np.float16: @@ -3380,8 +3424,10 @@ class TestSoftsign(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = ref_softsign(x) - self.inputs = {'X': x} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -3465,9 +3511,11 @@ class TestThresholdedRelu(TestActivation): x = np.random.uniform(-20, 20, self.shape).astype(self.dtype) x[np.abs(x) < 0.005] = 0.02 out = ref_thresholded_relu(x, threshold) - self.inputs = {'X': x} - self.attrs = {"threshold": threshold} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.attrs = 
{"threshold": threshold} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -3561,8 +3609,10 @@ class TestHardSigmoid(TestActivation): out = ref_hardsigmoid(x, self.slope, self.offset) self.attrs = {'slope': self.slope, 'offset': self.offset} - self.inputs = {'X': x} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -3666,9 +3716,11 @@ class TestSwish(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = ref_swish(x) - self.inputs = {'X': x} - self.attrs = {'beta': 1.0} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.attrs = {'beta': 1.0} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -3764,8 +3816,10 @@ class TestMish(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = ref_mish(x) - self.inputs = {'X': x} + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.convert_input_output() def init_shape(self): self.shape = [10, 12] @@ -3872,7 +3926,7 @@ def create_test_act_fp16_class( check_dygraph=True, check_prim=False, enable_cinn=True, - grad_atol=0.80, + grad_atol=1e-2, **kwargs ): @unittest.skipIf( @@ -3914,7 +3968,7 @@ def create_test_act_fp16_class( max_relative_error=grad_atol, ) - cls_name = "{}_{}".format(parent.__name__, "fp16") + cls_name = "{}_{}".format(parent.__name__, "FP16OP") TestActFp16.__name__ = cls_name globals()[cls_name] = TestActFp16 @@ -3933,17 +3987,17 @@ create_test_act_fp16_class(TestSqrtComp, check_prim=True) create_test_act_fp16_class(TestAbs, check_prim=True) create_test_act_fp16_class(TestCeil, grad_check=False) create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False) -create_test_act_fp16_class(TestCos, grad_atol=0.85) -create_test_act_fp16_class(TestTan, grad_atol=0.85) -create_test_act_fp16_class(TestCosh, grad_atol=0.85) -create_test_act_fp16_class(TestAcos, grad_atol=0.85) +create_test_act_fp16_class(TestCos) +create_test_act_fp16_class(TestTan) +create_test_act_fp16_class(TestCosh) +create_test_act_fp16_class(TestAcos) create_test_act_fp16_class(TestSin) create_test_act_fp16_class(TestSinh) create_test_act_fp16_class(TestAsin) create_test_act_fp16_class(TestAtan) -create_test_act_fp16_class(TestAcosh, grad_atol=0.85) -create_test_act_fp16_class(TestAsinh, grad_atol=0.85) -create_test_act_fp16_class(TestAtanh, grad_atol=0.85) +create_test_act_fp16_class(TestAcosh) +create_test_act_fp16_class(TestAsinh) +create_test_act_fp16_class(TestAtanh) create_test_act_fp16_class(TestRound, grad_check=False) create_test_act_fp16_class(TestRelu, check_prim=True) create_test_act_fp16_class( @@ -3955,38 +4009,63 @@ create_test_act_fp16_class( ) create_test_act_fp16_class(TestBRelu) create_test_act_fp16_class(TestRelu6) -create_test_act_fp16_class(TestSoftRelu, check_dygraph=False, grad_atol=0.85) +create_test_act_fp16_class(TestSoftRelu, check_dygraph=False) create_test_act_fp16_class(TestELU) create_test_act_fp16_class(TestCELU) create_test_act_fp16_class(TestReciprocal) create_test_act_fp16_class(TestLog, check_prim=True) if core.is_compiled_with_rocm(): - create_test_act_fp16_class(TestLog2, atol=5e-2, grad_atol=0.85) + create_test_act_fp16_class(TestLog2) else: - create_test_act_fp16_class(TestLog2, atol=5e-2) -create_test_act_fp16_class(TestLog10, atol=5e-2) -create_test_act_fp16_class(TestLog1p, 
grad_atol=0.9) + create_test_act_fp16_class(TestLog2) +create_test_act_fp16_class(TestLog10) +create_test_act_fp16_class(TestLog1p) create_test_act_fp16_class(TestSquare) -create_test_act_fp16_class(TestPow, check_prim=True, atol=5e-2) -create_test_act_fp16_class(TestPow_factor_tensor, atol=5e-2) -create_test_act_fp16_class(TestSTanh, grad_atol=0.9) +create_test_act_fp16_class(TestPow, check_prim=True) +create_test_act_fp16_class(TestPow_factor_tensor) +create_test_act_fp16_class(TestSTanh) create_test_act_fp16_class(TestSoftplus) create_test_act_fp16_class(TestSoftsign) create_test_act_fp16_class(TestThresholdedRelu) create_test_act_fp16_class(TestHardSigmoid) -create_test_act_fp16_class(TestSwish, grad_atol=0.85) +create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestHardSwish, check_prim=True) -create_test_act_fp16_class(TestMish, grad_atol=0.9) +create_test_act_fp16_class(TestMish) +create_test_act_fp16_class(TestLeakyRelu) +create_test_act_fp16_class(TestLeakyReluAlpha1) +create_test_act_fp16_class(TestLeakyReluAlpha2) +create_test_act_fp16_class(TestLeakyReluAlpha3) +create_test_act_fp16_class(TestLeakyRelu_ZeroDim) +create_test_act_fp16_class(TestRsqrt) def create_test_act_bf16_class( - parent, atol=1e-2, grad_check=True, grad_atol=0.80 + parent, + atol=1e-2, + grad_check=True, + check_dygraph=True, + check_prim=False, + enable_cinn=True, + grad_atol=1e-2, + **kwargs ): @unittest.skipIf( - not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", ) class TestActBF16(parent): + def setUp(self): + super().setUp() + for k, v in kwargs.items(): + setattr(self, k, v) + def init_dtype(self): + self.dtype = np.float32 + + def convert_input_output(self): + self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])} + self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])} self.dtype = np.uint16 def test_check_output(self): @@ -3995,17 +4074,80 @@ def create_test_act_bf16_class( def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=grad_atol - ) + if grad_check: + self.check_grad_with_place( + place, ['X'], 'Out', max_relative_error=grad_atol + ) - cls_name = "{}_{}".format(parent.__name__, "bf16") + cls_name = "{}_{}".format(parent.__name__, "BF16OP") TestActBF16.__name__ = cls_name globals()[cls_name] = TestActBF16 -create_test_act_bf16_class(TestRelu) -create_test_act_bf16_class(TestAbs) +create_test_act_bf16_class(TestActivation, check_prim=True) +create_test_act_bf16_class(TestExpm1) +create_test_act_bf16_class(TestSigmoid, check_prim=True) +create_test_act_bf16_class(TestSilu, check_prim=True) +create_test_act_bf16_class(TestLogSigmoid) +create_test_act_bf16_class(TestTanh) +create_test_act_bf16_class(TestTanhshrink) +create_test_act_bf16_class(TestHardShrink) +create_test_act_bf16_class(TestSoftshrink) +create_test_act_bf16_class(TestSqrt, check_prim=True) +create_test_act_bf16_class(TestSqrtComp, check_prim=True) +create_test_act_bf16_class(TestAbs, check_prim=True) +create_test_act_bf16_class(TestCeil, grad_check=False) +create_test_act_bf16_class(TestFloor, grad_check=False, check_prim=True) +create_test_act_bf16_class(TestCos) +create_test_act_bf16_class(TestTan) +create_test_act_bf16_class(TestCosh) +create_test_act_bf16_class(TestAcos) +create_test_act_bf16_class(TestSin) +create_test_act_bf16_class(TestSinh) 
+create_test_act_bf16_class(TestAsin) +create_test_act_bf16_class(TestAtan) +create_test_act_bf16_class(TestAcosh) +create_test_act_bf16_class(TestAsinh) +create_test_act_bf16_class(TestAtanh) +create_test_act_bf16_class(TestRound, grad_check=False) +create_test_act_bf16_class(TestRelu, check_prim=True) +create_test_act_bf16_class( + TestGelu, + check_prim=True, + enable_cinn=False, + rev_comp_rtol=1e-2, + rev_comp_atol=1e-2, +) +create_test_act_bf16_class(TestBRelu) +create_test_act_bf16_class(TestRelu6) +create_test_act_bf16_class(TestSoftRelu, check_dygraph=False) +create_test_act_bf16_class(TestELU) +create_test_act_bf16_class(TestCELU) +create_test_act_bf16_class(TestReciprocal) +create_test_act_bf16_class(TestLog, check_prim=True) +if core.is_compiled_with_rocm(): + create_test_act_bf16_class(TestLog2) +else: + create_test_act_bf16_class(TestLog2) +create_test_act_bf16_class(TestLog10) +create_test_act_bf16_class(TestLog1p) +create_test_act_bf16_class(TestSquare) +create_test_act_bf16_class(TestPow, check_prim=True) +create_test_act_bf16_class(TestPow_factor_tensor) +create_test_act_bf16_class(TestSTanh) +create_test_act_bf16_class(TestSoftplus) +create_test_act_bf16_class(TestSoftsign) +create_test_act_bf16_class(TestThresholdedRelu) +create_test_act_bf16_class(TestHardSigmoid) +create_test_act_bf16_class(TestSwish) +create_test_act_bf16_class(TestHardSwish, check_prim=True) +create_test_act_bf16_class(TestMish) +create_test_act_bf16_class(TestLeakyRelu) +create_test_act_bf16_class(TestLeakyReluAlpha1) +create_test_act_bf16_class(TestLeakyReluAlpha2) +create_test_act_bf16_class(TestLeakyReluAlpha3) +create_test_act_bf16_class(TestLeakyRelu_ZeroDim) +create_test_act_bf16_class(TestRsqrt) if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index eb7fb9df5e1..b86b7808aba 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle import paddle.nn.functional as F @@ -327,7 +327,7 @@ class RReluTest(OpTest): ] # python out sig is customized output signature. def init_params(self): - self.dtype = "float64" + self.init_dtype() self.x_shape = [2, 3, 4, 5] x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) @@ -337,12 +337,19 @@ class RReluTest(OpTest): self.inputs = {'X': x_np} self.outputs = {'Out': out_np, 'Noise': noise_np} + self.convert_input_output() self.attrs = { 'lower': self.lower, "upper": self.upper, "is_test": self.is_test, } + def init_dtype(self): + self.dtype = "float64" + + def convert_input_output(self): + pass + def test_check_output(self): self.check_output(no_check_set=['Noise']) @@ -363,5 +370,67 @@ class RReluTrainingTest(RReluTest): ] # python out sig is customized output signature. 
+class RReluTestFP16OP(RReluTest): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class RReluTestBF16OP(RReluTest): + def init_dtype(self): + self.dtype = np.float32 + + def convert_input_output(self): + self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])} + self.outputs = { + 'Out': convert_float_to_uint16(self.outputs['Out']), + 'Noise': convert_float_to_uint16(self.outputs['Noise']), + } + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, no_check_set=['Noise']) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + +class RReluTrainingTestFP16OP(RReluTrainingTest): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class RReluTrainingTestBF16OP(RReluTrainingTest): + def init_dtype(self): + self.dtype = np.float32 + + def convert_input_output(self): + self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])} + self.outputs = { + 'Out': convert_float_to_uint16(self.outputs['Out']), + 'Noise': convert_float_to_uint16(self.outputs['Noise']), + } + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, no_check_set=['Noise']) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + if __name__ == "__main__": unittest.main() -- GitLab
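Note on the BF16 test classes above: the reference values are still computed in float32; convert_input_output() only reinterprets the inputs and expected outputs as bfloat16 bit patterns via convert_float_to_uint16 and switches self.dtype to np.uint16, so the operator runs in bf16 while numpy never needs a bfloat16 dtype. The numpy-only sketch below shows one simple (truncating) way to do that float32 to bfloat16-bits round trip and the precision a bf16 test can reasonably assert on; float32_to_bf16_bits and bf16_bits_to_float32 are illustrative helpers written for this note, not the framework's own utilities.

```python
# Standalone illustration of the float32 <-> bfloat16-bits round trip that the
# BF16 OpTest classes rely on. These helpers are stand-ins for the test
# framework's convert_float_to_uint16; they truncate rather than round, which
# is enough to see the precision available to a bf16 check.
import numpy as np


def float32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    """Keep the upper 16 bits of each float32 value (sign, exponent, 7 mantissa bits)."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)


def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Expand uint16 bfloat16 bits back to float32 for comparison."""
    return (bits.astype(np.uint32) << 16).view(np.float32)


if __name__ == "__main__":
    x = np.random.uniform(-1, 1, [2, 3, 4, 5]).astype(np.float32)
    roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
    # bfloat16 carries only ~2-3 significant decimal digits, hence the
    # 1e-2-level atol/grad_atol defaults used by create_test_act_bf16_class.
    np.testing.assert_allclose(roundtrip, x, rtol=1e-2, atol=1e-2)
    print("max abs round-trip error:", float(np.max(np.abs(roundtrip - x))))
```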