Unverified · Commit 6d7ee668 authored by chenxujun, committed by GitHub

Add angle,bmm tests (#52630)

Parent 281ea2f4
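This commit registers float16 and bfloat16 variants of the GPU `angle`, `angle_grad`, `bmm`, and `bmm_grad` kernels and adds the corresponding unit tests. A minimal sketch of what it enables for `angle` (assuming a CUDA build of PaddlePaddle; for real input the angle is 0 where x >= 0 and pi where x < 0):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.linspace(-5, 5, 101).astype('float16'))
    theta = paddle.angle(x)   # float16 output, matches np.angle elementwise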
@@ -16,6 +16,7 @@
 #include "paddle/phi/kernels/impl/angle_grad_kernel_impl.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 PD_REGISTER_KERNEL(angle_grad,
@@ -24,6 +25,8 @@ PD_REGISTER_KERNEL(angle_grad,
                    phi::AngleGradKernel,
                    float,
                    double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {
   kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
......
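`kernel->InputAt(1)` here is the incoming output gradient: since `angle` maps a complex input to a real output, the out-grad input is registered with the real counterpart (`ToReal`) of the kernel dtype. For complex z = x + iy, angle(z) = atan2(y, x), giving d(theta)/dx = -y/(x^2 + y^2) and d(theta)/dy = x/(x^2 + y^2). A numpy sketch of that reference gradient (a hypothetical helper, written to mirror the `angle_grad` helper the tests below rely on):

    import numpy as np

    def angle_grad_ref(z, dout):
        """Analytic gradient of angle for complex z, packed into one complex
        array. Illustrative only; for real z the gradient is zero almost
        everywhere, since angle is piecewise constant on the real line."""
        r2 = np.real(z) ** 2 + np.imag(z) ** 2
        return -dout * np.imag(z) / r2 + 1j * dout * np.real(z) / r2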
@@ -17,6 +17,7 @@
 #include "paddle/phi/kernels/impl/angle_kernel_impl.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 PD_REGISTER_KERNEL(angle,
@@ -25,6 +26,8 @@ PD_REGISTER_KERNEL(angle,
                    phi::AngleKernel,
                    float,
                    double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
......
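`SetDataType(phi::DataType::UNDEFINED)` on the output tells the dispatcher not to force the output dtype to match the kernel key, because `angle` of a complex tensor yields a real-typed tensor:

    import paddle

    z = paddle.to_tensor([1 + 1j], dtype='complex64')
    print(paddle.angle(z).dtype)   # paddle.float32 -- the real counterpart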
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(bmm_grad,
                    phi::BmmGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/bmm_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    bmm, GPU, ALL_LAYOUT, phi::BmmKernel, float, double, phi::dtype::float16) {}
+PD_REGISTER_KERNEL(bmm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BmmKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
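With this registration `paddle.bmm` accepts float16 and bfloat16 on GPU. `bmm` takes 3-D inputs shaped [B, M, K] and [B, K, N] and returns [B, M, N], with no broadcasting (unlike `matmul`). A minimal sketch, assuming a CUDA build of PaddlePaddle:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.rand(10, 3, 4).astype('float16'))  # [B, M, K]
    y = paddle.to_tensor(np.random.rand(10, 4, 5).astype('float16'))  # [B, K, N]
    out = paddle.bmm(x, y)   # shape [10, 3, 5]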
@@ -15,11 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import static
-from paddle.fluid import dygraph
+from paddle.fluid import core, dygraph
 
 paddle.enable_static()
@@ -61,6 +61,51 @@ class TestAngleOpFloat(OpTest):
         )
 
 
+class TestAngleFP16Op(TestAngleOpFloat):
+    def setUp(self):
+        self.op_type = "angle"
+        self.python_api = paddle.angle
+        self.dtype = "float16"
+        self.x = np.linspace(-5, 5, 101).astype(self.dtype)
+        out_ref = np.angle(self.x)
+        self.inputs = {'X': self.x}
+        self.outputs = {'Out': out_ref}
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestAngleBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "angle"
+        self.python_api = paddle.angle
+        self.dtype = np.uint16
+        self.np_dtype = np.float32
+        self.x = np.linspace(-5, 5, 101).astype(self.np_dtype)
+        out_ref = np.angle(self.x)
+        self.inputs = {'X': self.x}
+        self.outputs = {'Out': out_ref}
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            self.place,
+            ['X'],
+            'Out',
+            user_defined_grads=[
+                angle_grad(self.x, np.ones_like(self.x) / self.x.size)
+            ],
+        )
+
+
 class TestAngleOpComplex(OpTest):
     def setUp(self):
         self.op_type = "angle"
......
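`convert_float_to_uint16` (from `eager_op_test`) packs float32 data into the uint16 bit pattern of bfloat16, which is how OpTest feeds bfloat16 tensors: Paddle spells the bfloat16 dtype as `np.uint16` on the Python side. A sketch of the underlying bit trick (illustrative helper name; rounding details omitted):

    import numpy as np

    def float32_to_bfloat16_bits(a):
        """Truncate float32 to bfloat16 and return the raw uint16 bits.
        Illustrative only; the real helper also handles rounding."""
        a = np.asarray(a, dtype=np.float32)
        return (a.view(np.uint32) >> 16).astype(np.uint16)

    bits = float32_to_bfloat16_bits(np.float32(1.5))   # 0x3FC0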
@@ -15,10 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest, paddle_static_guard
+from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 class TestBmmOp(OpTest):
@@ -38,6 +39,52 @@ class TestBmmOp(OpTest):
         self.check_grad(['X', 'Y'], 'Out')
 
 
+class TestBmmFP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "bmm"
+        self.dtype = np.float16
+        self.python_api = paddle.tensor.bmm
+        X = np.random.random((10, 3, 4)).astype("float16")
+        Y = np.random.random((10, 4, 5)).astype("float16")
+        self.inputs = {'X': X, 'Y': Y}
+        Out = np.matmul(X, Y)
+        self.outputs = {'Out': Out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestBmmBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "bmm"
+        self.dtype = np.uint16
+        self.python_api = paddle.tensor.bmm
+        X = np.random.random((10, 3, 4)).astype("float32")
+        Y = np.random.random((10, 4, 5)).astype("float32")
+        self.inputs = {'X': X, 'Y': Y}
+        Out = np.matmul(X, Y)
+        self.outputs = {'Out': Out}
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
+
+
 class API_TestBmm(unittest.TestCase):
     def test_out(self):
         with paddle_static_guard():
......
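Both new bmm tests validate against `np.matmul` computed in the source precision. For 3-D arrays `np.matmul` already performs the batched product that `bmm` implements, i.e. per batch index `Out[i] = X[i] @ Y[i]`:

    import numpy as np

    X = np.random.random((10, 3, 4)).astype(np.float32)
    Y = np.random.random((10, 4, 5)).astype(np.float32)
    ref = np.stack([X[i] @ Y[i] for i in range(10)])   # explicit per-batch matmul
    assert np.allclose(ref, np.matmul(X, Y))           # matmul batches over dim 0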
@@ -4858,7 +4858,17 @@
         return _C_ops.angle(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'angle'
+            x,
+            'x',
+            [
+                'float16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+                'uint16',
+            ],
+            'angle',
         )
         op_type = "angle"
         helper = LayerHelper(op_type, **locals())
......
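In this static-graph branch, `'uint16'` in the allowed dtype list is how Paddle's static type checks spell bfloat16. A sketch of the usage the widened check permits (graph construction only; running the program still needs a device with the float16 kernel registered above):

    import paddle
    from paddle import static

    paddle.enable_static()
    prog = static.Program()
    with static.program_guard(prog):
        x = static.data(name='x', shape=[4], dtype='float16')
        out = paddle.angle(x)   # passes the widened dtype check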