未验证 提交 f3aec871 编写于 作者: Z Zhang Ting 提交者: GitHub

[bug fix] fix fp16 dtype checking for argmax op (#50811)

* fix fp16 dtype checking for argmax op

* run fp16 test when place is gpu

* Update search.py

fix doc
上级 587120ec
...@@ -366,5 +366,18 @@ class TestArgMinMaxOpError(unittest.TestCase): ...@@ -366,5 +366,18 @@ class TestArgMinMaxOpError(unittest.TestCase):
self.assertRaises(ValueError, test_argmin_dtype_type) self.assertRaises(ValueError, test_argmin_dtype_type)
class TestArgMaxOpFp16(unittest.TestCase):
def test_fp16(self):
x_np = np.random.random((10, 16)).astype('float16')
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmax(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[out])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -127,7 +127,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -127,7 +127,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
element along the provided axis. element along the provided axis.
Args: Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16, x(Tensor): An input N-D Tensor with type float16, float32, float64, int16,
int32, int64, uint8. int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way is [-R, R), where R is x.ndim. when axis < 0, it works the same way
...@@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], [
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'uint8',
],
'paddle.argmax', 'paddle.argmax',
) )
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册