未验证 提交 b0337433 编写于 作者: W Wen Sun 提交者: GitHub

Fix `broadcast_object_list` & `scatter_object_list` not work in specified group (#51762)

* fix: fix broadcast object list

* fix: fix scatter object list
上级 391b210f
......@@ -118,7 +118,7 @@ def broadcast_object_list(object_list, src, group=None):
obj_size_tensor = paddle.concat(obj_sizes)
else:
obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
broadcast(obj_size_tensor, src)
broadcast(obj_size_tensor, src, group)
if rank == src:
# cast to uint8 to keep the same dtype
......@@ -126,7 +126,7 @@ def broadcast_object_list(object_list, src, group=None):
else:
data_len = paddle.sum(obj_size_tensor).item()
obj_data_tensor = paddle.empty([data_len], dtype="uint8")
broadcast(obj_data_tensor, src)
broadcast(obj_data_tensor, src, group)
offset = 0
for i in range(obj_nums):
......
......@@ -135,11 +135,11 @@ def scatter_object_list(
in_tensor = paddle.to_tensor(numpy_data)
in_tensor_list.append(in_tensor)
out_tensor = paddle.empty([max_obj_size], dtype="uint8")
scatter(out_tensor, in_tensor_list if rank == src else None, src)
scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
# NOTE: shape can be [] after 0D tensor support
out_tensor_size = paddle.empty([1], dtype="int64")
scatter(out_tensor_size, in_obj_sizes if rank == src else None, src)
scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
out_object_list.clear()
out_object_list.append(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册