未验证 提交 b0337433 编写于 作者: W Wen Sun 提交者: GitHub

Fix `broadcast_object_list` & `scatter_object_list` not work in specified group (#51762)

* fix: fix broadcast object list

* fix: fix scatter object list
上级 391b210f
...@@ -118,7 +118,7 @@ def broadcast_object_list(object_list, src, group=None): ...@@ -118,7 +118,7 @@ def broadcast_object_list(object_list, src, group=None):
obj_size_tensor = paddle.concat(obj_sizes) obj_size_tensor = paddle.concat(obj_sizes)
else: else:
obj_size_tensor = paddle.empty([obj_nums], dtype="int64") obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
broadcast(obj_size_tensor, src) broadcast(obj_size_tensor, src, group)
if rank == src: if rank == src:
# cast to uint8 to keep the same dtype # cast to uint8 to keep the same dtype
...@@ -126,7 +126,7 @@ def broadcast_object_list(object_list, src, group=None): ...@@ -126,7 +126,7 @@ def broadcast_object_list(object_list, src, group=None):
else: else:
data_len = paddle.sum(obj_size_tensor).item() data_len = paddle.sum(obj_size_tensor).item()
obj_data_tensor = paddle.empty([data_len], dtype="uint8") obj_data_tensor = paddle.empty([data_len], dtype="uint8")
broadcast(obj_data_tensor, src) broadcast(obj_data_tensor, src, group)
offset = 0 offset = 0
for i in range(obj_nums): for i in range(obj_nums):
......
...@@ -135,11 +135,11 @@ def scatter_object_list( ...@@ -135,11 +135,11 @@ def scatter_object_list(
in_tensor = paddle.to_tensor(numpy_data) in_tensor = paddle.to_tensor(numpy_data)
in_tensor_list.append(in_tensor) in_tensor_list.append(in_tensor)
out_tensor = paddle.empty([max_obj_size], dtype="uint8") out_tensor = paddle.empty([max_obj_size], dtype="uint8")
scatter(out_tensor, in_tensor_list if rank == src else None, src) scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
# NOTE: shape can be [] after 0D tensor support # NOTE: shape can be [] after 0D tensor support
out_tensor_size = paddle.empty([1], dtype="int64") out_tensor_size = paddle.empty([1], dtype="int64")
scatter(out_tensor_size, in_obj_sizes if rank == src else None, src) scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
out_object_list.clear() out_object_list.clear()
out_object_list.append( out_object_list.append(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册