未验证 提交 ccfe7681 编写于 作者: W Weilong Wu 提交者: GitHub

support reshape test on prim and cinn (#51276)

* support reshape test on prim and cinn

* fix mkldnn test

* polish test case
上级 782454bd
......@@ -1219,7 +1219,8 @@ set(TEST_CINN_OPS
test_elementwise_mul_op
test_gather_nd_op
test_elementwise_pow_op
test_transpose_op)
test_transpose_op
test_reshape_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -19,22 +19,36 @@ import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import (
OpTest,
OpTestTool,
convert_float_to_uint16,
)
from paddle.fluid.tests.unittests.test_reshape_op import TestReshapeOp
paddle.enable_static()
class TestReshape2OneDNNOp(TestReshapeOp):
class TestReshape2OneDNNOp(OpTest):
def setUp(self):
super().setUp()
self.init_data()
self.op_type = "reshape2"
self.python_api = paddle.tensor.reshape
self.python_out_sig = ['Out']
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"shape": self.new_shape}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype("float32"),
}
self.x = self.inputs["X"]
self.attrs['use_mkldnn'] = True
self.set_additional_inputs()
self.set_outputs()
def init_data(self):
self.ori_shape = (2, 60)
self.new_shape = (12, 10)
self.infered_shape = (12, 10)
def init_dtype(self):
self.dtype = np.float32
......@@ -44,6 +58,12 @@ class TestReshape2OneDNNOp(TestReshapeOp):
def set_outputs(self):
pass
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshape2OneDNNOpDimInfer1(TestReshape2OneDNNOp):
def init_data(self):
......
......@@ -27,6 +27,7 @@ class TestReshapeOp(OpTest):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
self.prim_op_type = "prim"
self.python_api = paddle.tensor.reshape
self.python_out_sig = ['Out']
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
......@@ -45,17 +46,31 @@ class TestReshapeOp(OpTest):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", check_prim=True)
class TestReshapeOp_ZeroDim1(TestReshapeOp):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
self.prim_op_type = "prim"
self.enable_cinn = False
self.python_api = paddle.tensor.reshape
self.python_out_sig = ['Out']
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"shape": self.new_shape}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype("float32"),
}
class TestReshapeOp_ZeroDim1(OpTest):
def init_data(self):
self.ori_shape = ()
self.new_shape = (1,)
self.infered_shape = (1,)
class TestReshapeOp_ZeroDim2(OpTest):
class TestReshapeOp_ZeroDim2(TestReshapeOp_ZeroDim1):
def init_data(self):
self.ori_shape = ()
self.new_shape = (-1,)
......@@ -73,6 +88,8 @@ class TestReshapeBF16Op(OpTest):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
self.prim_op_type = "prim"
self.enable_cinn = False
self.python_api = paddle.tensor.reshape
self.python_out_sig = ['Out']
self.dtype = np.uint16
......@@ -96,7 +113,7 @@ class TestReshapeBF16Op(OpTest):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", check_prim=True)
class TestReshapeOpDimInfer1(TestReshapeOp):
......
......@@ -30,4 +30,5 @@ NEED_TO_FIX_OP_LIST = [
'rnn',
'multi_dot',
'index_add',
'reshape2',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册