diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 1b1b1d7c983282974d2fa46038c35c98de4f9ec2..74f76030a29d2c9ce27278b61548c8877c1467ad 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase): def test_argmax_dtype_type(): data = paddle.static.data( name="test_argmax", shape=[10], dtype="float32") - output = paddle.argmax(x=data, dtype=1) + output = paddle.argmax(x=data, dtype=None) - self.assertRaises(TypeError, test_argmax_dtype_type) + self.assertRaises(ValueError, test_argmax_dtype_type) def test_argmin_dtype_type(): data = paddle.static.data( name="test_argmin", shape=[10], dtype="float32") - output = paddle.argmin(x=data, dtype=1) + output = paddle.argmin(x=data, dtype=None) - self.assertRaises(TypeError, test_argmin_dtype_type) + self.assertRaises(ValueError, test_argmin_dtype_type) if __name__ == '__main__': diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index ce03d0ef15f0f80f4e01cf57bc8cc449186c2560..f55d285586f0ec6959573af64e720bea5de10c8d 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): "The type of 'axis' must be int or None in argmax, but received %s." % (type(axis))) - if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): - raise TypeError( - "the type of 'dtype' in argmax must be str or np.dtype, but received {}". - format(type(dtype))) + if dtype is None: + raise ValueError( + "the value of 'dtype' in argmax could not be None, but received None" + ) var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') @@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): "The type of 'axis' must be int or None in argmin, but received %s." % (type(axis))) - if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): - raise TypeError( - "the type of 'dtype' in argmin must be str or np.dtype, but received {}". - format(dtype(dtype))) + if dtype is None: + raise ValueError( + "the value of 'dtype' in argmin could not be None, but received None" + ) var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')