Unverified commit cb6de765, authored by chenxujun, committed via GitHub

Add overlap_add, sign tests (#52667)

Parent a67d3bb7
@@ -32,6 +32,7 @@ struct EigenSign<Eigen::GpuDevice, T> {
 template struct EigenSign<Eigen::GpuDevice, float>;
 template struct EigenSign<Eigen::GpuDevice, double>;
 template struct EigenSign<Eigen::GpuDevice, dtype::float16>;
+template struct EigenSign<Eigen::GpuDevice, dtype::bfloat16>;
 }  // namespace funcs
 }  // namespace phi
@@ -161,5 +161,6 @@ PD_REGISTER_KERNEL(overlap_add_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -147,5 +147,6 @@ PD_REGISTER_KERNEL(overlap_add,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -19,9 +19,13 @@ limitations under the License. */
 #include "paddle/phi/kernels/impl/sign_kernel_impl.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/amp_type_traits.h"

-using float16 = phi::dtype::float16;
-
-PD_REGISTER_KERNEL(
-    sign, GPU, ALL_LAYOUT, phi::SignKernel, float, double, float16) {}
+PD_REGISTER_KERNEL(sign,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SignKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
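With bfloat16 added to the GPU registration above, the kernel can be exercised from Python. Below is a minimal sketch, assuming a CUDA build of Paddle in which these bfloat16 kernels are available; the tensor values are only illustrative:

import paddle

paddle.set_device('gpu')  # the new bfloat16 sign kernel is registered for GPU

x = paddle.to_tensor([-3.0, 0.0, 2.5])
x_bf16 = paddle.cast(x, 'bfloat16')          # cast to bfloat16 before the op

out = paddle.sign(x_bf16)                    # should dispatch to the bfloat16 GPU kernel
print(paddle.cast(out, 'float32').numpy())   # [-1.  0.  1.]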
@@ -15,14 +15,15 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
+from paddle.fluid import core


 def overlap_add(x, hop_length, axis=-1):
     assert axis in [0, -1], 'axis should be 0/-1.'
-    assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.'
+    assert len(x.shape) >= 2, 'Input dims should be >= 2.'
     squeeze_output = False
     if len(x.shape) == 2:
@@ -101,6 +102,58 @@ class TestOverlapAddOp(OpTest):
         paddle.disable_static()
+
+
+class TestOverlapAddFP16Op(TestOverlapAddOp):
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = 'float16'
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestOverlapAddBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "overlap_add"
+        self.python_api = paddle.signal.overlap_add
+        self.shape, self.type, self.attrs = self.initTestCase()
+        self.np_dtype = np.float32
+        self.dtype = np.uint16
+
+        self.inputs = {
+            'X': np.random.random(size=self.shape).astype(self.np_dtype),
+        }
+        self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = np.uint16
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output_with_place(self.place)
+        paddle.disable_static()
+
+    def test_check_grad_normal(self):
+        paddle.enable_static()
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+        paddle.disable_static()
+
 class TestCase1(TestOverlapAddOp):
     def initTestCase(self):
         input_shape = (3, 50)
......
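The BF16 test builds its inputs with convert_float_to_uint16, i.e. uint16 arrays carrying bfloat16 bit patterns. As a rough illustration (not Paddle's actual helper, which may round to nearest rather than truncate), bfloat16 is simply the top 16 bits of an IEEE float32:

import numpy as np

def float32_to_bfloat16_bits(x):
    # Illustrative only: bfloat16 shares the sign, exponent and top mantissa
    # bits of float32, so keeping the high 16 bits yields its bit pattern.
    # Paddle's convert_float_to_uint16 may additionally round; this sketch
    # just truncates.
    x = np.asarray(x, dtype=np.float32)
    bits = x.view(np.uint32)                 # reinterpret the raw float32 bits
    return (bits >> 16).astype(np.uint16)    # drop the low 16 mantissa bits

print(float32_to_bfloat16_bits([1.0, -2.0]))  # [16256 49152]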
@@ -17,7 +17,7 @@ import unittest
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
@@ -40,6 +40,42 @@ class TestSignOp(OpTest):
         self.check_grad(['X'], 'Out')
+
+
+class TestSignFP16Op(TestSignOp):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float16")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestSignBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.dtype = np.uint16
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float32")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
 class TestSignOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -97,7 +133,7 @@ class TestSignDoubleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32
@@ -128,7 +164,7 @@ class TestSignTripleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32
......
@@ -219,7 +219,10 @@ def overlap_add(x, hop_length, axis=-1, name=None):
         out = op(x, *attrs)
     else:
         check_variable_and_dtype(
-            x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'], op_type
+            x,
+            'x',
+            ['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
+            op_type,
         )
         helper = LayerHelper(op_type, **locals())
         dtype = helper.input_dtype(input_param_name='x')
......
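Here 'uint16' is the static-graph spelling of bfloat16 in the dtype whitelist. A minimal dynamic-graph sketch of the overlap_add forward and backward paths with bfloat16 input, assuming a CUDA build with the kernels registered above (shapes and values are illustrative):

import paddle

paddle.set_device('gpu')  # the bfloat16 kernels registered above are GPU kernels

# x: (frame_length, num_frames) when axis=-1; here frame_length=8, num_frames=5
x = paddle.cast(paddle.rand([8, 5]), 'bfloat16')
x.stop_gradient = False

out = paddle.signal.overlap_add(x, hop_length=4)  # forward: overlap_add kernel
out.sum().backward()                              # backward: overlap_add_grad kernel

print(out.shape)     # [(5 - 1) * 4 + 8] = [24]
print(x.grad.shape)  # [8, 5]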
@@ -3677,7 +3677,7 @@ def sign(x, name=None):
         return _C_ops.sign(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'sign'
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sign'
         )
         helper = LayerHelper("sign", **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
......