Unverified commit 0e3f7ab1, authored by LoneRanger, committed by GitHub


【PaddlePaddle Hackathon 4】No.56 : add fp16 test and bf16 test for diag, diagonal, fill and fill_diagonal_tensor (#51649)
Parent cf7c431f
```diff
@@ -20,4 +20,5 @@ REGISTER_OP_CUDA_KERNEL(fill,
                         ops::FillKernel<double>,
                         ops::FillKernel<int64_t>,
                         ops::FillKernel<int>,
-                        ops::FillKernel<paddle::platform::float16>);
+                        ops::FillKernel<paddle::platform::float16>,
+                        ops::FillKernel<paddle::platform::bfloat16>);
```
```diff
@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(diag_grad,
                    ALL_LAYOUT,
                    phi::DiagGradKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t,
                    float,
```
```diff
@@ -135,6 +135,7 @@ PD_REGISTER_KERNEL(diag,
                    ALL_LAYOUT,
                    phi::DiagKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t,
                    float,
```
```diff
@@ -15,29 +15,39 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
 
 
 class TestDiagV2Op(OpTest):
     def setUp(self):
         self.op_type = "diag_v2"
         self.python_api = paddle.diag
-        self.x = np.random.rand(10, 10)
+        self.init_dtype()
+        self.init_attrs()
+        self.init_input_output()
+
+    def init_dtype(self):
+        self.dtype = np.float64
+
+    def init_attrs(self):
         self.offset = 0
         self.padding_value = 0.0
-        self.out = np.diag(self.x, self.offset)
-
         self.init_config()
-        self.inputs = {'X': self.x}
+
+    def init_input_output(self):
+        x = np.random.rand(10, 10).astype(self.dtype)
+        out = np.diag(x, self.offset)
+
         self.attrs = {
             'offset': self.offset,
             'padding_value': self.padding_value,
         }
-        self.outputs = {'Out': self.out}
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
 
     def test_check_output(self):
         paddle.enable_static()
```
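The new imports matter because OpTest has no native bfloat16 array type: bf16 tensors are passed to and from the framework as raw `np.uint16` bit patterns, and `convert_float_to_uint16` performs that packing. The idea is that bfloat16 is simply the upper 16 bits of an IEEE-754 float32; the truncating numpy sketch below illustrates the bit trick (Paddle's real helper also applies rounding, so treat this as an illustration, not its implementation).

```python
import numpy as np

def bf16_bits(a: np.ndarray) -> np.ndarray:
    # bfloat16 keeps the sign, exponent, and top 7 mantissa bits of a
    # float32, i.e. its upper 16 bits, so it packs losslessly into uint16.
    return (a.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

x = np.array([1.0, 3.140625, -2.5], dtype=np.float32)
print(bf16_bits(x))  # 1.0 -> 0x3F80, the top half of 0x3F800000
```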
```diff
@@ -47,9 +57,6 @@ class TestDiagV2Op(OpTest):
         paddle.enable_static()
         self.check_grad(['X'], 'Out')
 
-    def init_config(self):
-        pass
-
 
 class TestDiagV2OpCase1(TestDiagV2Op):
     def init_config(self):
```
```diff
@@ -298,6 +305,44 @@ class TestDiagV2API(unittest.TestCase):
         self.run_static(use_gpu=True)
 
 
+class TestDiagV2FP16OP(TestDiagV2Op):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestDiagV2BF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "diag_v2"
+        self.python_api = paddle.diag
+        self.dtype = np.uint16
+        x = np.random.rand(10, 10).astype(np.float32)
+        offset = 0
+        padding_value = 0.0
+        out = np.diag(x, offset)
+
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {
+            'offset': offset,
+            'padding_value': padding_value,
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
```
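For reference, the ground truth in these tests is plain `np.diag`: given a 2-D input it extracts the diagonal shifted by `offset`, and given a 1-D input it builds a square matrix with that diagonal. (`padding_value` is a Paddle-side extra with no numpy counterpart.) A quick numpy reference:

```python
import numpy as np

m = np.arange(9).reshape(3, 3)
print(np.diag(m))             # [0 4 8]  main diagonal
print(np.diag(m, 1))          # [1 5]    offset=1, above the diagonal
print(np.diag(np.arange(3)))  # 3x3 matrix with [0 1 2] on its diagonal
```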
```diff
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
```
```diff
@@ -26,6 +27,7 @@ class TestDiagonalOp(OpTest):
     def setUp(self):
         self.op_type = "diagonal"
         self.python_api = paddle.diagonal
+        self.init_dtype()
         self.init_config()
         self.outputs = {'Out': self.target}
```
```diff
@@ -35,8 +37,11 @@ class TestDiagonalOp(OpTest):
     def test_check_grad(self):
         self.check_grad(['Input'], 'Out')
 
+    def init_dtype(self):
+        self.dtype = 'float64'
+
     def init_config(self):
-        self.case = np.random.randn(10, 5, 2).astype('float64')
+        self.case = np.random.randn(10, 5, 2).astype(self.dtype)
         self.inputs = {'Input': self.case}
         self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
         self.target = np.diagonal(
```
```diff
@@ -172,5 +177,43 @@ class TestDiagonalAPI(unittest.TestCase):
         paddle.enable_static()
 
 
+class TestDiagonalFP16OP(TestDiagonalOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestDiagonalBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "diagonal"
+        self.python_api = paddle.diagonal
+        self.dtype = np.uint16
+        self.init_config()
+        self.outputs = {'Out': convert_float_to_uint16(self.target)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['Input'], 'Out')
+
+    def init_config(self):
+        self.case = np.random.randn(10, 5, 2).astype(np.float32)
+        self.inputs = {'Input': convert_float_to_uint16(self.case)}
+        self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
+        self.target = np.diagonal(
+            self.case,
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        ).copy()
+
+
 if __name__ == '__main__':
     unittest.main()
```
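The `diagonal` tests use `np.diagonal` as ground truth: `axis1` and `axis2` pick the pair of axes forming the 2-D planes, and the diagonal is taken within each plane. The `.copy()` in `init_config` is needed because `np.diagonal` returns a read-only view. A quick numpy reference:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
d = np.diagonal(a, offset=0, axis1=0, axis2=1)
# Axes 0 and 1 are consumed; the diagonal has min(2, 3) = 2 entries,
# appended as the last axis of the result.
print(d.shape)  # (4, 2)
```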
```diff
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 def fill_diagonal_ndarray(x, value, offset=0, dim1=0, dim2=1):
```
```diff
@@ -148,6 +149,57 @@ class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test):
         self.dtype = np.float16
 
 
+class TensorFillDiagTensorFP16OP(TensorFillDiagTensor_Test):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TensorFillDiagTensorBF16(OpTest):
+    def setUp(self):
+        self.op_type = "fill_diagonal_tensor"
+        self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
+        self.init_kernel_type()
+        self.init_config()
+        self.init_input_output()
+
+    def init_kernel_type(self):
+        self.dtype = np.uint16
+
+    def init_config(self):
+        self.x = np.random.random((10, 10)).astype(np.float32)
+        self.y = np.random.random((10,)).astype(np.float32)
+        self.dim1 = 0
+        self.dim2 = 1
+        self.offset = 0
+
+    def init_input_output(self):
+        out = fill_gt(self.x, self.y, self.offset, self.dim1, self.dim2)
+
+        self.inputs = {
+            "X": convert_float_to_uint16(self.x),
+            "Y": convert_float_to_uint16(self.y),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.attrs = {
+            "offset": self.offset,
+            "dim1": self.dim1,
+            "dim2": self.dim2,
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
```
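The `fill_gt` helper used in `init_input_output` is defined earlier in this test file and computes the expected result in numpy. For the exact configuration tested here (2-D `x`, `dim1=0`, `dim2=1`, `offset=0`) it reduces to `np.fill_diagonal`; a sketch of that equivalence, under those assumptions:

```python
import numpy as np

# 2-D, offset=0 case only; the file's fill_gt also handles offsets
# and higher-rank inputs.
x = np.random.random((10, 10)).astype(np.float32)
y = np.random.random((10,)).astype(np.float32)
out = x.copy()
np.fill_diagonal(out, y)  # writes y[i] at out[i, i]
```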
```diff
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from paddle.fluid import core
 from paddle.fluid.op import Operator
```
```diff
@@ -24,6 +24,7 @@ from paddle.fluid.op import Operator
 class TestFillOp1(OpTest):
     def setUp(self):
         self.op_type = "fill"
+        self.init_dtype()
         val = np.random.random(size=[100, 200])
         self.inputs = {}
         self.attrs = {
```
```diff
@@ -34,6 +35,9 @@
         }
         self.outputs = {'Out': val.astype('float64')}
 
+    def init_dtype(self):
+        self.dtype = np.float64
+
     def test_check_output(self):
         self.check_output()
```
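Unlike most ops, `fill` takes no input tensor: the payload travels through attributes, with `value` holding the flattened data and `shape` restoring the layout. The round trip, in plain numpy:

```python
import numpy as np

val = np.random.random(size=[100, 200])
flat = val.flatten().tolist()                # what the 'value' attr carries
restored = np.array(flat).reshape(100, 200)  # what the kernel rebuilds
assert np.array_equal(val, restored)         # float64 survives exactly
```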
```diff
@@ -89,5 +93,34 @@ class TestFillOp3(unittest.TestCase):
         self.check_with_place(place, False)
 
 
+class TestFillFP16OP(TestFillOp1):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestFillBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "fill"
+        self.dtype = np.uint16
+        val = np.random.random(size=[100, 200])
+        self.inputs = {}
+        self.attrs = {
+            'value': val.flatten().tolist(),
+            'shape': [100, 200],
+            'dtype': int(core.VarDesc.VarType.BF16),
+            'force_cpu': False,
+        }
+        self.outputs = {'Out': convert_float_to_uint16(val)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 if __name__ == '__main__':
     unittest.main()
```
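All four BF16 test classes repeat the same `skipIf` guard. Factored into a helper, the pattern looks like the sketch below (not something this commit adds):

```python
from paddle.fluid import core

def bf16_cuda_place():
    """Return CUDAPlace(0) if this build can run bfloat16 CUDA kernels."""
    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(0)
        if core.is_bfloat16_supported(place):
            return place
    return None
```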