diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h
index 69365357084b660b7c2f90149fe250854ea6a014..c296ddcfbef703e8484b6ea0b7f96f037e415186 100644
--- a/paddle/fluid/operators/arg_min_max_op_base.h
+++ b/paddle/fluid/operators/arg_min_max_op_base.h
@@ -166,10 +166,22 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
 
+    const int& dtype = ctx->Attrs().Get<int>("dtype");
+    PADDLE_ENFORCE_EQ(
+        (dtype < 0 || dtype == 2 || dtype == 3), true,
+        platform::errors::InvalidArgument(
+            "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
+            "received [%s]",
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64),
+            paddle::framework::DataTypeToString(
+                static_cast<framework::proto::VarType::Type>(dtype))));
+
     auto x_rank = x_dims.size();
     if (axis < 0) axis += x_rank;
     if (ctx->IsRuntime()) {
-      const int& dtype = ctx->Attrs().Get<int>("dtype");
       if (dtype == framework::proto::VarType::INT32) {
         int64_t all_element_num = 0;
         if (flatten) {
diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
index 0fd9863948aedb64052e8fa0668f03600ae3197c..1b1b1d7c983282974d2fa46038c35c98de4f9ec2 100644
--- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
@@ -322,6 +322,20 @@ class TestArgMinMaxOpError(unittest.TestCase):
 
         self.assertRaises(TypeError, test_argmin_axis_type)
 
+        def test_argmax_dtype_type():
+            data = paddle.static.data(
+                name="test_argmax", shape=[10], dtype="float32")
+            output = paddle.argmax(x=data, dtype=1)
+
+        self.assertRaises(TypeError, test_argmax_dtype_type)
+
+        def test_argmin_dtype_type():
+            data = paddle.static.data(
+                name="test_argmin", shape=[10], dtype="float32")
+            output = paddle.argmin(x=data, dtype=1)
+
+        self.assertRaises(TypeError, test_argmin_dtype_type)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 7543b091383f535108b02bb0ec08521942bb7e6a..ce03d0ef15f0f80f4e01cf57bc8cc449186c2560 100644
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -166,6 +166,12 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
         raise TypeError(
             "The type of 'axis' must be int or None in argmax, but received %s."
             % (type(axis)))
+
+    if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
+        raise TypeError(
+            "the type of 'dtype' in argmax must be str or np.dtype, but received {}".
+            format(type(dtype)))
+
     var_dtype = convert_np_dtype_to_dtype_(dtype)
     check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False
@@ -238,6 +244,12 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
         raise TypeError(
             "The type of 'axis' must be int or None in argmin, but received %s."
             % (type(axis)))
+
+    if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
+        raise TypeError(
+            "the type of 'dtype' in argmin must be str or np.dtype, but received {}".
+            format(type(dtype)))
+
     var_dtype = convert_np_dtype_to_dtype_(dtype)
     check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False