diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index f5cb975019c982cc86663dadbf8f4b5c365c6513..d23648ba65fe3fb8f3e32f6dfcffe4469d7476bb 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -366,5 +366,18 @@ class TestArgMinMaxOpError(unittest.TestCase): self.assertRaises(ValueError, test_argmin_dtype_type) +class TestArgMaxOpFp16(unittest.TestCase): + def test_fp16(self): + x_np = np.random.random((10, 16)).astype('float16') + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[10, 16], name='x', dtype='float16') + out = paddle.argmax(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 2b6b4671efbffb7cec25320ad6ab8067f2f2fff9..30ddfb13986fd43f2e046c730a4ad42229529f92 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -127,7 +127,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): element along the provided axis. Args: - x(Tensor): An input N-D Tensor with type float32, float64, int16, + x(Tensor): An input N-D Tensor with type float16, float32, float64, int16, int32, int64, uint8. axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way @@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + [ + 'float16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'uint8', + ], 'paddle.argmax', ) check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')