未验证 提交 77298931 编写于 作者: 张春乔 提交者: GitHub

suppot fp16 in broadcast (#50905)

上级 d832a54d
...@@ -70,6 +70,41 @@ class TestBroadcastToAPI(unittest.TestCase): ...@@ -70,6 +70,41 @@ class TestBroadcastToAPI(unittest.TestCase):
assert np.array_equal(res_2, np.tile(input, (1, 1))) assert np.array_equal(res_2, np.tile(input, (1, 1)))
assert np.array_equal(res_3, np.tile(input, (1, 1))) assert np.array_equal(res_3, np.tile(input, (1, 1)))
def test_api_fp16_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([12, 14]).astype("float16")
x = paddle.static.data(
name="x", shape=[12, 14], dtype="float16"
)
positive_2 = paddle.fluid.layers.fill_constant([1], "int32", 12)
expand_shape = paddle.static.data(
name="expand_shape",
shape=[2],
dtype="int32",
)
out_1 = paddle.broadcast_to(x, shape=[12, 14])
out_2 = paddle.broadcast_to(x, shape=[positive_2, 14])
out_3 = paddle.broadcast_to(x, shape=expand_shape)
exe = paddle.static.Executor(place)
res_1, res_2, res_3 = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
"expand_shape": np.array([12, 14]).astype("int32"),
},
fetch_list=[out_1, out_2, out_3],
)
assert np.array_equal(res_1, np.tile(input, (1, 1)))
assert np.array_equal(res_2, np.tile(input, (1, 1)))
assert np.array_equal(res_3, np.tile(input, (1, 1)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -3256,7 +3256,7 @@ def broadcast_to(x, shape, name=None): ...@@ -3256,7 +3256,7 @@ def broadcast_to(x, shape, name=None):
Args: Args:
x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. x (Tensor): The input tensor, its data type is bool, float16, float32, float64, int32 or int64.
shape (list|tuple|Tensor): The result shape after broadcasting. The data type is int32. If shape is a list or tuple, all its elements shape (list|tuple|Tensor): The result shape after broadcasting. The data type is int32. If shape is a list or tuple, all its elements
should be integers or 0-D or 1-D Tensors with the data type int32. If shape is a Tensor, it should be an 1-D Tensor with the data type int32. should be integers or 0-D or 1-D Tensors with the data type int32. If shape is a Tensor, it should be an 1-D Tensor with the data type int32.
The value -1 in shape means keeping the corresponding dimension unchanged. The value -1 in shape means keeping the corresponding dimension unchanged.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册