From b6c1493ab282fb2af6e5b4d46741f5a97b8937d6 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 29 Mar 2023 16:35:57 +0800 Subject: [PATCH] [AMP OP&Test] add float16 optest for reshape_op (#51678) * [AMP OP&Test] add float16 optest for reshape_op * add public_python_api --- .../fluid/tests/unittests/test_reshape_op.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 24b8661694e..d5acc54d572 100755 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -119,6 +119,34 @@ class TestReshapeBF16Op(OpTest): self.check_grad(["X"], "Out", check_prim=True) +class TestReshapeFP16Op(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + self.prim_op_type = "prim" + self.python_api = paddle.tensor.reshape + self.public_python_api = paddle.tensor.reshape + self.python_out_sig = ['Out'] + self.dtype = np.float16 + self.inputs = {"X": np.random.random(self.ori_shape).astype(self.dtype)} + self.attrs = {"shape": self.new_shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + 'XShape': np.random.random(self.ori_shape).astype(self.dtype), + } + + def init_data(self): + self.ori_shape = (2, 60) + self.new_shape = (12, 10) + self.infered_shape = (12, 10) + + def test_check_output(self): + self.check_output(no_check_set=['XShape']) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_prim=True) + + class TestReshapeOpDimInfer1(TestReshapeOp): def init_data(self): self.ori_shape = (5, 25) -- GitLab