From cb6de765d969d1350e8b29f0235b4aa5040ee590 Mon Sep 17 00:00:00 2001
From: chenxujun
Date: Thu, 13 Apr 2023 15:08:35 +0800
Subject: [PATCH] Add overlap_add, sign tests (#52667)

---
 paddle/phi/kernels/funcs/eigen/sign.cu         |  1 +
 .../kernels/gpu/overlap_add_grad_kernel.cu     |  1 +
 paddle/phi/kernels/gpu/overlap_add_kernel.cu   |  1 +
 paddle/phi/kernels/gpu/sign_kernel.cu.cc       | 16 ++++--
 .../tests/unittests/test_overlap_add_op.py     | 57 ++++++++++++++++++-
 .../fluid/tests/unittests/test_sign_op.py      | 42 +++++++++++++-
 python/paddle/signal.py                        |  5 +-
 python/paddle/tensor/math.py                   |  2 +-
 8 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/paddle/phi/kernels/funcs/eigen/sign.cu b/paddle/phi/kernels/funcs/eigen/sign.cu
index 4caed688013..b630ba7bb6c 100644
--- a/paddle/phi/kernels/funcs/eigen/sign.cu
+++ b/paddle/phi/kernels/funcs/eigen/sign.cu
@@ -32,6 +32,7 @@ struct EigenSign {
 template struct EigenSign<Eigen::GpuDevice, float>;
 template struct EigenSign<Eigen::GpuDevice, double>;
 template struct EigenSign<Eigen::GpuDevice, dtype::float16>;
+template struct EigenSign<Eigen::GpuDevice, dtype::bfloat16>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu b/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
index 057f7e465c0..a2ec60109d6 100644
--- a/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
@@ -161,5 +161,6 @@ PD_REGISTER_KERNEL(overlap_add_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/overlap_add_kernel.cu b/paddle/phi/kernels/gpu/overlap_add_kernel.cu
index cf56095db5e..b8726b8d8e1 100644
--- a/paddle/phi/kernels/gpu/overlap_add_kernel.cu
+++ b/paddle/phi/kernels/gpu/overlap_add_kernel.cu
@@ -147,5 +147,6 @@ PD_REGISTER_KERNEL(overlap_add,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/sign_kernel.cu.cc b/paddle/phi/kernels/gpu/sign_kernel.cu.cc
index 37f10243dc5..71cd1d39b68 100644
--- a/paddle/phi/kernels/gpu/sign_kernel.cu.cc
+++ b/paddle/phi/kernels/gpu/sign_kernel.cu.cc
@@ -19,9 +19,13 @@ limitations under the License. */
 #include "paddle/phi/kernels/impl/sign_kernel_impl.h"
 
 // See Note [ Why still include the fluid headers? ]
-#include "paddle/phi/common/float16.h"
-
-using float16 = phi::dtype::float16;
-
-PD_REGISTER_KERNEL(
-    sign, GPU, ALL_LAYOUT, phi::SignKernel, float, double, float16) {}
+#include "paddle/phi/common/amp_type_traits.h"
+
+PD_REGISTER_KERNEL(sign,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SignKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_overlap_add_op.py b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
index d0e5cd79c3b..98d4ce10aaa 100644
--- a/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
@@ -15,14 +15,15 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 def overlap_add(x, hop_length, axis=-1):
     assert axis in [0, -1], 'axis should be 0/-1.'
-    assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.'
+    assert len(x.shape) >= 2, 'Input dims should be >= 2.'
 
     squeeze_output = False
     if len(x.shape) == 2:
@@ -101,6 +102,58 @@ class TestOverlapAddOp(OpTest):
         paddle.disable_static()
 
 
+class TestOverlapAddFP16Op(TestOverlapAddOp):
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = 'float16'
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestOverlapAddBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "overlap_add"
+        self.python_api = paddle.signal.overlap_add
+        self.shape, self.type, self.attrs = self.initTestCase()
+        self.np_dtype = np.float32
+        self.dtype = np.uint16
+        self.inputs = {
+            'X': np.random.random(size=self.shape).astype(self.np_dtype),
+        }
+        self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = np.uint16
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output_with_place(self.place)
+        paddle.disable_static()
+
+    def test_check_grad_normal(self):
+        paddle.enable_static()
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+        paddle.disable_static()
+
+
 class TestCase1(TestOverlapAddOp):
     def initTestCase(self):
         input_shape = (3, 50)
diff --git a/python/paddle/fluid/tests/unittests/test_sign_op.py b/python/paddle/fluid/tests/unittests/test_sign_op.py
index 79ee4ceff5f..2617c2451f3 100644
--- a/python/paddle/fluid/tests/unittests/test_sign_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sign_op.py
@@ -17,7 +17,7 @@ import unittest
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
@@ -40,6 +40,42 @@ class TestSignOp(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestSignFP16Op(TestSignOp):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float16")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestSignBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.dtype = np.uint16
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float32")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
 class TestSignOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -97,7 +133,7 @@ class TestSignDoubleGradCheck(unittest.TestCase):
 
     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32
 
@@ -128,7 +164,7 @@ class TestSignTripleGradCheck(unittest.TestCase):
 
     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32
 
diff --git a/python/paddle/signal.py b/python/paddle/signal.py
index e404ec08ffb..e1580b00075 100644
--- a/python/paddle/signal.py
+++ b/python/paddle/signal.py
@@ -219,7 +219,10 @@ def overlap_add(x, hop_length, axis=-1, name=None):
         out = op(x, *attrs)
     else:
         check_variable_and_dtype(
-            x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'], op_type
+            x,
+            'x',
+            ['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
+            op_type,
         )
         helper = LayerHelper(op_type, **locals())
         dtype = helper.input_dtype(input_param_name='x')
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 1e969be8804..fe412003787 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3677,7 +3677,7 @@ def sign(x, name=None):
         return _C_ops.sign(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'sign'
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sign'
         )
         helper = LayerHelper("sign", **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
--
GitLab
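
A quick way to sanity-check what this patch enables (not part of the commit; a
minimal sketch that assumes a CUDA build of Paddle on a GPU with bfloat16
support, with illustrative shapes and variable names):

    import paddle

    paddle.device.set_device('gpu')

    # Cast float32 to bfloat16; with this patch the GPU sign and overlap_add
    # kernels (forward and grad) are registered for bfloat16.
    x = paddle.rand([50, 3]).astype('bfloat16')

    s = paddle.sign(x)  # elementwise -1/0/+1, dtype preserved
    # 3 frames of length 50, overlap-added along the last axis with hop 4:
    # output length = (3 - 1) * 4 + 50 = 58
    y = paddle.signal.overlap_add(x, hop_length=4, axis=-1)

    print(s.dtype, y.shape)  # paddle.bfloat16, [58]

The unit tests carry bfloat16 data as np.uint16 because numpy has no native
bfloat16 dtype: convert_float_to_uint16 keeps the upper 16 bits of each
float32 value, which is the bfloat16 bit pattern, and 'uint16' is likewise the
stand-in name added to the static-graph dtype checks in signal.py and math.py.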