diff --git a/paddle/cinn/hlir/op/contrib/argmin.cc b/paddle/cinn/hlir/op/contrib/argmin.cc
index d0d1c63a332e7e88656167ba3cf37e3f12a3673f..ada26762e1938a22699504c56f9a7b90a9c8df9d 100644
--- a/paddle/cinn/hlir/op/contrib/argmin.cc
+++ b/paddle/cinn/hlir/op/contrib/argmin.cc
@@ -183,17 +183,22 @@ std::vector<std::vector<int>> InferShapeForArgmin(
     const framework::AttrMapType &attrs) {
   CHECK_EQ(inputs_shape.size(), 1UL);
   auto ndim = inputs_shape[0].size();
-  CHECK_GT(ndim, 0) << "tensor's dim must be more than 0";
   int axis;
   bool keep_dim;
 
   CHECK(attrs.find("axis") != attrs.end());
   axis = absl::get<int>(attrs.at("axis"));
-  if (axis < 0) {
-    axis = static_cast<int>(ndim) + axis;
+  if (ndim > 0) {
+    if (axis < 0) {
+      axis = static_cast<int>(ndim) + axis;
+    }
+    CHECK_LT(axis, ndim) << "Axis must be less than tensor's dim";
+    CHECK_GE(axis, 0) << "Axis must be more than 0";
+  } else {
+    // 0D Tensor
+    CHECK(axis == 0 || axis == -1)
+        << "Axis must be 0 or -1 if input tensor is 0-dim";
   }
-  CHECK_LT(axis, ndim) << "Axis must be less than tensor's dim";
-  CHECK_GE(axis, 0) << "Axis must be more than 0";
 
   CHECK(attrs.find("keep_dim") != attrs.end());
   keep_dim = absl::get<bool>(attrs.at("keep_dim"));
@@ -212,11 +217,7 @@ std::vector<std::vector<int>> InferShapeForArgmin(
   if (keep_dim) {
     CHECK_EQ(ndim, out_shapes.size());
   } else {
-    CHECK_EQ(ndim - 1, out_shapes.size());
-  }
-
-  if (out_shapes.empty()) {
-    out_shapes.push_back(1);
+    CHECK(ndim - 1 == out_shapes.size() || ndim == 0 && out_shapes.empty());
   }
 
   return {out_shapes};
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 8e3dbdf98d6faee6315fcf1a858ff57f5eddd430..86e57e7e2feaa043be40abe8297fd2fdf175ee66 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -698,7 +698,7 @@ class TestArgmaxOp(OpTest):
         self.paddle_outputs = [out]
 
     def build_cinn_program(self, target):
-        builder = NetBuilder("squeeze_op")
+        builder = NetBuilder("argmax_op")
         x = builder.create_input(
             cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
         )
@@ -750,6 +750,82 @@ class TestArgmaxOp2DKeepDim(TestArgmaxOp):
         self.target_shape = (1, 5)
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestArgminOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.param = (0,)
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.argmin(x, *self.param)
+        out = paddle.cast(out, 'int32')
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("argmin_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.argmin(x, *self.param)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestArgminOp2(TestArgminOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.param = (-1,)
+        self.target_shape = ()
+
+
+class TestArgminOp1D(TestArgminOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, [5]).astype(self.dtype),
+        }
+        self.param = (0,)
+        self.target_shape = ()
+
+
+class TestArgminOp2D(TestArgminOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, [3, 5]).astype(self.dtype),
+        }
+        self.param = (0,)
+        self.target_shape = (5,)
+
+
+class TestArgminOp2DKeepDim(TestArgminOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, [3, 5]).astype(self.dtype),
+        }
+        self.param = (0, True)
+        self.target_shape = (1, 5)
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )
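For reference, a minimal sketch of the 0-dim semantics the new shape-inference checks and tests encode, assuming a Paddle build with 0D tensor support; this snippet is illustrative only and not part of the patch:

import numpy as np
import paddle

# For a 0-dim input, only axis 0 or -1 is accepted, and the result
# stays 0-dim (it is no longer padded to shape [1]).
x = paddle.to_tensor(np.random.randint(-10, 10, []).astype("float32"))
out = paddle.argmin(x, 0)
assert out.shape == []   # 0-dim in, 0-dim out
assert out.item() == 0   # the single element is trivially the minimum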