未验证 提交 69d49aba 编写于 作者: I Infinity_lee 提交者: GitHub

[fp16] support fp16 in argmin (#50858)

上级 72cbb6da
......@@ -379,5 +379,18 @@ class TestArgMaxOpFp16(unittest.TestCase):
out = exe.run(feed={'x': x_np}, fetch_list=[out])
class TestArgMinOpFp16(unittest.TestCase):
    def test_fp16(self):
        """Smoke-test paddle.argmin on a float16 input in static-graph mode.

        Builds a static program that feeds a (10, 16) float16 array through
        paddle.argmin and, when a CUDA build is available, executes it on
        GPU. Only checks that construction and execution succeed — no
        numerical assertion is made.
        """
        input_np = np.random.random((10, 16)).astype('float16')
        with paddle.static.program_guard(paddle.static.Program()):
            x_var = paddle.static.data(
                shape=[10, 16], name='x', dtype='float16'
            )
            result = paddle.argmin(x_var)
            # float16 kernels are only registered for CUDA; skip on CPU-only builds.
            if core.is_compiled_with_cuda():
                gpu_place = paddle.CUDAPlace(0)
                executor = paddle.static.Executor(gpu_place)
                executor.run(paddle.static.default_startup_program())
                result = executor.run(
                    feed={'x': input_np}, fetch_list=[result]
                )
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -224,7 +224,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
element along the provided axis.
Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16,
x(Tensor): An input N-D Tensor with type float16, float32, float64, int16,
int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
......@@ -282,7 +282,15 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
[
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'uint8',
],
'paddle.argmin',
)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册