未验证 提交 0072490f 编写于 作者: W wawltor 提交者: GitHub

udpate the dtype check for the argmin, argmax

remove the bug for the checking the type of dtype
上级 f2c21ff9
...@@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase): ...@@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase):
def test_argmax_dtype_type(): def test_argmax_dtype_type():
data = paddle.static.data( data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32") name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmax(x=data, dtype=1) output = paddle.argmax(x=data, dtype=None)
self.assertRaises(TypeError, test_argmax_dtype_type) self.assertRaises(ValueError, test_argmax_dtype_type)
def test_argmin_dtype_type(): def test_argmin_dtype_type():
data = paddle.static.data( data = paddle.static.data(
name="test_argmin", shape=[10], dtype="float32") name="test_argmin", shape=[10], dtype="float32")
output = paddle.argmin(x=data, dtype=1) output = paddle.argmin(x=data, dtype=None)
self.assertRaises(TypeError, test_argmin_dtype_type) self.assertRaises(ValueError, test_argmin_dtype_type)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmax, but received %s." "The type of 'axis' must be int or None in argmax, but received %s."
% (type(axis))) % (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): if dtype is None:
raise TypeError( raise ValueError(
"the type of 'dtype' in argmax must be str or np.dtype, but received {}". "the value of 'dtype' in argmax could not be None, but received None"
format(type(dtype))) )
var_dtype = convert_np_dtype_to_dtype_(dtype) var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
...@@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmin, but received %s." "The type of 'axis' must be int or None in argmin, but received %s."
% (type(axis))) % (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): if dtype is None:
raise TypeError( raise ValueError(
"the type of 'dtype' in argmin must be str or np.dtype, but received {}". "the value of 'dtype' in argmin could not be None, but received None"
format(dtype(dtype))) )
var_dtype = convert_np_dtype_to_dtype_(dtype) var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册