diff --git a/paddle/cinn/hlir/op/transform.cc b/paddle/cinn/hlir/op/transform.cc
index 7df8e440c6838173001cf82e283e17548f6c7cb4..12272582044ab7541419a6e1ca3890a5d90db618 100644
--- a/paddle/cinn/hlir/op/transform.cc
+++ b/paddle/cinn/hlir/op/transform.cc
@@ -1020,8 +1020,8 @@ std::shared_ptr<OpStrategy> StrategyForReverse(
 std::vector<std::vector<int>> InferShapeForReverse(
     const std::vector<std::vector<int>> &inputs_shape,
     const framework::AttrMapType &attrs) {
-  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
-      << "The input's shape size is 0! Please check again.";
+  CHECK(!inputs_shape.empty())
+      << "The input's shape is empty! Please check again.";
   std::vector<std::vector<int>> res{inputs_shape[0]};
   if (attrs.find("axis") != attrs.end()) {
     auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 2b5ae1f5cac710865d09dccc3850f03100ff4f28..4467fe32d3d6c61dc14e531e881b8fce148b3a56 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -776,6 +776,44 @@ class TestBroadcastToOp3D(TestBroadcastToOp1D):
         self.broadcast_shape = [3, 3, 3]
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestReverseOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.reverse(x, axis=[])
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("reverse_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.reverse(x, [])
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )
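
For context, a minimal sketch of the 0-d tensor behavior the new test exercises, assuming a Paddle build with 0-d tensor support (the input value here is illustrative, not from the PR):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.float32(2.0))  # 0-d tensor, shape []
    out = paddle.reverse(x, axis=[])       # reversing over no axes is the identity
    assert out.shape == []                 # the 0-d shape is preserved
    assert float(out) == 2.0

The relaxed CHECK in InferShapeForReverse is what allows this case through: a 0-d input arrives as an empty shape (inputs_shape[0] == {}), which the old `!inputs_shape[0].empty()` condition rejected even though reversing over an empty axis list is well defined.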