未验证 提交 f547ee92 编写于 作者: C chenxujun 提交者: GitHub

Add kron float16/bfloat16, unbind float16 tests (#52413)

上级 3c949ba9
......@@ -27,5 +27,6 @@ PD_REGISTER_KERNEL(kron_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -27,5 +27,6 @@ PD_REGISTER_KERNEL(kron,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -15,11 +15,12 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid.dygraph as dg
from paddle import fluid
from paddle.fluid import core
class TestKronOp(OpTest):
......@@ -73,6 +74,50 @@ class TestKronOp3(TestKronOp):
self.outputs = {'Out': out_ref}
class TestKronFP16Op(TestKronOp):
    """Float16 variant of the kron operator test.

    Overrides only the dtype hook; all test logic is inherited
    from TestKronOp.
    """

    def _init_dtype(self):
        # Run the inherited kron checks with half-precision inputs.
        return "float16"
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestKronBF16Op(TestKronOp):
    """Bfloat16 variant of the kron operator test (CUDA only).

    Inputs are generated in float32 and then packed into uint16 bit
    patterns via convert_float_to_uint16, which is how bfloat16
    tensors are fed through the OpTest framework.
    """

    def setUp(self):
        self.op_type = "kron"
        self.python_api = paddle.kron
        # OpTest carries bfloat16 data as raw uint16 bit patterns.
        self.dtype = np.uint16
        self.np_dtype = "float32"
        lhs = np.random.uniform(size=(10, 10)).astype(self.np_dtype)
        rhs = np.random.uniform(size=(10, 10)).astype(self.np_dtype)
        expected = np.kron(lhs, rhs)
        self.inputs = {
            'X': convert_float_to_uint16(lhs),
            'Y': convert_float_to_uint16(rhs),
        }
        self.outputs = {'Out': convert_float_to_uint16(expected)}
        # bfloat16 kernels only run on GPU, so pin the test place.
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')

    def test_check_grad_ignore_x(self):
        # Gradient w.r.t. Y only; X is excluded from the grad set.
        self.check_grad_with_place(
            self.place, ['Y'], 'Out', no_grad_set=set('X')
        )

    def test_check_grad_ignore_y(self):
        # Gradient w.r.t. X only; Y is excluded from the grad set.
        self.check_grad_with_place(
            self.place, ['X'], 'Out', no_grad_set=set('Y')
        )
class TestKronLayer(unittest.TestCase):
def test_case(self):
a = np.random.randn(10, 10).astype(np.float64)
......
......@@ -199,6 +199,30 @@ class TestUnbindOp4(TestUnbindOp):
self.out[1] = self.out[1].reshape((3, 2))
class TestUnbindFP16Op(OpTest):
    """Float16 test for the unbind op.

    Splits a (3, 2, 2) tensor into three slices along axis 0 and
    checks the forward output only (no gradient check here).
    """

    def setUp(self):
        paddle.disable_static()
        self.op_type = "unbind"
        self.python_api = paddle.unbind
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
        data = np.arange(12).reshape(3, 2, 2).astype(self.dtype)
        self.out = np.split(data, self.num, self.axis)
        self.inputs = {'X': data}
        self.attrs = {'axis': self.axis}
        self.outputs = {
            'Out': [
                ('out%d' % i, piece) for i, piece in enumerate(self.out)
            ]
        }
        self.python_out_sig = ['out%d' % i for i in range(len(self.out))]

    def get_dtype(self):
        # Hook so subclasses can swap in a different dtype.
        return np.float16

    def test_check_output(self):
        self.check_output()
class TestUnbindBF16Op(OpTest):
def setUp(self):
paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册