未验证 提交 f5e4a316 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports reshape (#55326)

上级 d1b74ba5
...@@ -805,6 +805,84 @@ class TestDropoutOp(OpTest): ...@@ -805,6 +805,84 @@ class TestDropoutOp(OpTest):
self.check_outputs_and_grads() self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestReshapeOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = [1]
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.reshape(x, self.target_shape)
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("reshape_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.reshape(x, self.target_shape)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(list(res[0].shape), [1] * len(self.target_shape))
def test_check_results(self):
self.check_outputs_and_grads()
class TestReshapeOp0DTo2D(TestReshapeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = [1, 1]
class TestReshapeOp0DTo1D_DS(TestReshapeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = [-1]
class TestReshapeOp0DTo2D_DS(TestReshapeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = [-1, 1]
class TestReshapeOp0DTo0D(TestReshapeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = []
class TestReshapeOp1DTo0D(TestReshapeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [1]).astype(self.dtype),
}
self.target_shape = []
@OpTestTool.skip_if( @OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册