未验证 提交 5f60b597 编写于 作者: 张春乔 提交者: GitHub

support fp16 on unbind (#50916)

上级 336cd205
...@@ -42,6 +42,29 @@ class TestUnbind(unittest.TestCase): ...@@ -42,6 +42,29 @@ class TestUnbind(unittest.TestCase):
assert np.array_equal(res_1, input_1[0, 0:100]) assert np.array_equal(res_1, input_1[0, 0:100])
assert np.array_equal(res_2, input_1[1, 0:100]) assert np.array_equal(res_2, input_1[1, 0:100])
def test_unbind_static_fp16_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([2, 3]).astype("float16")
x = paddle.static.data(name="x", shape=[2, 3], dtype="float16")
y = paddle.unbind(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)
assert np.array_equal(res[0], input[0, :])
assert np.array_equal(res[1], input[1, :])
def test_unbind_dygraph(self): def test_unbind_dygraph(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
np_x = np.random.random([2, 3]).astype("float32") np_x = np.random.random([2, 3]).astype("float32")
......
...@@ -2738,7 +2738,7 @@ def unbind(input, axis=0): ...@@ -2738,7 +2738,7 @@ def unbind(input, axis=0):
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors. Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
Args: Args:
input (Tensor): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64. input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64.
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0. If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
Returns: Returns:
...@@ -2785,7 +2785,10 @@ def unbind(input, axis=0): ...@@ -2785,7 +2785,10 @@ def unbind(input, axis=0):
check_type(input, 'input', (Variable), 'unbind') check_type(input, 'input', (Variable), 'unbind')
dtype = helper.input_dtype() dtype = helper.input_dtype()
check_dtype( check_dtype(
dtype, 'unbind', ['float32', 'float64', 'int32', 'int64'], 'unbind' dtype,
'unbind',
['float16', 'float32', 'float64', 'int32', 'int64'],
'unbind',
) )
outs = [ outs = [
helper.create_variable_for_type_inference( helper.create_variable_for_type_inference(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册