未验证 提交 39d5bb6d 编写于 作者: W wawltor 提交者: GitHub

udpate the dtype check for the argmin, argmax

fix the bug for dtype check for the argmin/argmax
上级 9b7692b1
......@@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase):
def test_argmax_dtype_type():
data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmax(x=data, dtype=1)
output = paddle.argmax(x=data, dtype=None)
self.assertRaises(TypeError, test_argmax_dtype_type)
self.assertRaises(ValueError, test_argmax_dtype_type)
def test_argmin_dtype_type():
data = paddle.static.data(
name="test_argmin", shape=[10], dtype="float32")
output = paddle.argmin(x=data, dtype=1)
output = paddle.argmin(x=data, dtype=None)
self.assertRaises(TypeError, test_argmin_dtype_type)
self.assertRaises(ValueError, test_argmin_dtype_type)
if __name__ == '__main__':
......
......@@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmax, but received %s."
% (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
raise TypeError(
"the type of 'dtype' in argmax must be str or np.dtype, but received {}".
format(type(dtype)))
if dtype is None:
raise ValueError(
"the value of 'dtype' in argmax could not be None, but received None"
)
var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
......@@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmin, but received %s."
% (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)):
raise TypeError(
"the type of 'dtype' in argmin must be str or np.dtype, but received {}".
format(dtype(dtype)))
if dtype is None:
raise ValueError(
"the value of 'dtype' in argmin could not be None, but received None"
)
var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册