diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py index 16c110b2298a202ce36808fb3a4e2eb977948d3d..26a7ce52375cf6c29c1727b9bec7dfaa2bdc12d8 100644 --- a/test/cinn/ops/test_zero_dim_tensor.py +++ b/test/cinn/ops/test_zero_dim_tensor.py @@ -805,6 +805,84 @@ class TestDropoutOp(OpTest): self.check_outputs_and_grads() +@OpTestTool.skip_if( + not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." +) +class TestReshapeOp(OpTest): + def setUp(self): + np.random.seed(2023) + self.dtype = "float32" + self.init_input() + + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.target_shape = [1] + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) + out = paddle.reshape(x, self.target_shape) + + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("reshape_op") + x = builder.create_input( + cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x" + ) + out = builder.reshape(x, self.target_shape) + + prog = builder.build() + res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out]) + + self.cinn_outputs = res + self.assertEqual(list(res[0].shape), [1] * len(self.target_shape)) + + def test_check_results(self): + self.check_outputs_and_grads() + + +class TestReshapeOp0DTo2D(TestReshapeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.target_shape = [1, 1] + + +class TestReshapeOp0DTo1D_DS(TestReshapeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.target_shape = [-1] + + +class TestReshapeOp0DTo2D_DS(TestReshapeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.target_shape = [-1, 1] + + +class TestReshapeOp0DTo0D(TestReshapeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, []).astype(self.dtype), + } + self.target_shape = [] + + +class TestReshapeOp1DTo0D(TestReshapeOp): + def init_input(self): + self.inputs = { + "x": np.random.randint(-10, 10, [1]).astype(self.dtype), + } + self.target_shape = [] + + @OpTestTool.skip_if( not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." )