未验证 提交 b6c1493a 编写于 作者: Z zhaoyingli 提交者: GitHub

[AMP OP&Test] add float16 optest for reshape_op (#51678)

* [AMP OP&Test] add float16 optest for reshape_op

* add public_python_api
上级 fc02b1e6
...@@ -119,6 +119,34 @@ class TestReshapeBF16Op(OpTest): ...@@ -119,6 +119,34 @@ class TestReshapeBF16Op(OpTest):
self.check_grad(["X"], "Out", check_prim=True) self.check_grad(["X"], "Out", check_prim=True)
class TestReshapeFP16Op(OpTest):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
self.prim_op_type = "prim"
self.python_api = paddle.tensor.reshape
self.public_python_api = paddle.tensor.reshape
self.python_out_sig = ['Out']
self.dtype = np.float16
self.inputs = {"X": np.random.random(self.ori_shape).astype(self.dtype)}
self.attrs = {"shape": self.new_shape}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype(self.dtype),
}
def init_data(self):
self.ori_shape = (2, 60)
self.new_shape = (12, 10)
self.infered_shape = (12, 10)
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out", check_prim=True)
class TestReshapeOpDimInfer1(TestReshapeOp): class TestReshapeOpDimInfer1(TestReshapeOp):
def init_data(self): def init_data(self):
self.ori_shape = (5, 25) self.ori_shape = (5, 25)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册