diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 24b8661694edbfde50458dd36ef22b9ba9a3fef7..d5acc54d5721b58d46aade3566c80c8e0533f46b 100755 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -119,6 +119,34 @@ class TestReshapeBF16Op(OpTest): self.check_grad(["X"], "Out", check_prim=True) +class TestReshapeFP16Op(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + self.prim_op_type = "prim" + self.python_api = paddle.tensor.reshape + self.public_python_api = paddle.tensor.reshape + self.python_out_sig = ['Out'] + self.dtype = np.float16 + self.inputs = {"X": np.random.random(self.ori_shape).astype(self.dtype)} + self.attrs = {"shape": self.new_shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + 'XShape': np.random.random(self.ori_shape).astype(self.dtype), + } + + def init_data(self): + self.ori_shape = (2, 60) + self.new_shape = (12, 10) + self.infered_shape = (12, 10) + + def test_check_output(self): + self.check_output(no_check_set=['XShape']) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_prim=True) + + class TestReshapeOpDimInfer1(TestReshapeOp): def init_data(self): self.ori_shape = (5, 25)