From 0072490f9eaf46600c8d62d183ca6ca76727b328 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 8 Sep 2020 10:16:10 +0800 Subject: [PATCH] udpate the dtype check for the argmin, argmax remove the bug for the checking the type of dtype --- .../tests/unittests/test_arg_min_max_v2_op.py | 8 ++++---- python/paddle/tensor/search.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 1b1b1d7c983..74f76030a29 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase): def test_argmax_dtype_type(): data = paddle.static.data( name="test_argmax", shape=[10], dtype="float32") - output = paddle.argmax(x=data, dtype=1) + output = paddle.argmax(x=data, dtype=None) - self.assertRaises(TypeError, test_argmax_dtype_type) + self.assertRaises(ValueError, test_argmax_dtype_type) def test_argmin_dtype_type(): data = paddle.static.data( name="test_argmin", shape=[10], dtype="float32") - output = paddle.argmin(x=data, dtype=1) + output = paddle.argmin(x=data, dtype=None) - self.assertRaises(TypeError, test_argmin_dtype_type) + self.assertRaises(ValueError, test_argmin_dtype_type) if __name__ == '__main__': diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index ce03d0ef15f..f55d285586f 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): "The type of 'axis' must be int or None in argmax, but received %s." % (type(axis))) - if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): - raise TypeError( - "the type of 'dtype' in argmax must be str or np.dtype, but received {}". - format(type(dtype))) + if dtype is None: + raise ValueError( + "the value of 'dtype' in argmax could not be None, but received None" + ) var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') @@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): "The type of 'axis' must be int or None in argmin, but received %s." % (type(axis))) - if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): - raise TypeError( - "the type of 'dtype' in argmin must be str or np.dtype, but received {}". - format(dtype(dtype))) + if dtype is None: + raise ValueError( + "the value of 'dtype' in argmin could not be None, but received None" + ) var_dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') -- GitLab