From b0337433ba9c5c513cd70ccb8d1ceee8c43d16fc Mon Sep 17 00:00:00 2001
From: Wen Sun <35923278+HermitSun@users.noreply.github.com>
Date: Sat, 18 Mar 2023 00:12:21 +0800
Subject: [PATCH] Fix `broadcast_object_list` & `scatter_object_list` not
 working in a specified group (#51762)

* fix: fix broadcast object list

* fix: fix scatter object list
---
 python/paddle/distributed/communication/broadcast.py | 4 ++--
 python/paddle/distributed/communication/scatter.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/communication/broadcast.py b/python/paddle/distributed/communication/broadcast.py
index 2e5dde826b..114dddb69e 100644
--- a/python/paddle/distributed/communication/broadcast.py
+++ b/python/paddle/distributed/communication/broadcast.py
@@ -118,7 +118,7 @@ def broadcast_object_list(object_list, src, group=None):
         obj_size_tensor = paddle.concat(obj_sizes)
     else:
         obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
-    broadcast(obj_size_tensor, src)
+    broadcast(obj_size_tensor, src, group)
 
     if rank == src:
         # cast to uint8 to keep the same dtype
@@ -126,7 +126,7 @@ def broadcast_object_list(object_list, src, group=None):
     else:
         data_len = paddle.sum(obj_size_tensor).item()
         obj_data_tensor = paddle.empty([data_len], dtype="uint8")
-    broadcast(obj_data_tensor, src)
+    broadcast(obj_data_tensor, src, group)
 
     offset = 0
     for i in range(obj_nums):
diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 455bb5d1cf..625954c391 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -135,11 +135,11 @@ def scatter_object_list(
         in_tensor = paddle.to_tensor(numpy_data)
         in_tensor_list.append(in_tensor)
     out_tensor = paddle.empty([max_obj_size], dtype="uint8")
-    scatter(out_tensor, in_tensor_list if rank == src else None, src)
+    scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
 
     # NOTE: shape can be [] after 0D tensor support
     out_tensor_size = paddle.empty([1], dtype="int64")
-    scatter(out_tensor_size, in_obj_sizes if rank == src else None, src)
+    scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
 
     out_object_list.clear()
     out_object_list.append(
-- 
GitLab
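
For context, below is a minimal usage sketch of the call pattern this patch fixes: passing an explicit `group` to the object collectives. It assumes two launched workers and that `broadcast_object_list` / `scatter_object_list` are exposed under `paddle.distributed` as in recent Paddle releases; the launcher command, script name, and group ranks are illustrative, not part of the patch.

# Illustrative sketch only (not part of the patch). Launch with two workers,
# e.g. `python -m paddle.distributed.launch --gpus 0,1 demo.py` (command and
# file name are assumptions; adjust to your environment).
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# Explicit communication group; before this fix the inner broadcast/scatter
# calls ignored it and always used the global group.
group = dist.new_group(ranks=[0, 1])

# broadcast_object_list: non-src ranks pass placeholders of the same length
# and receive rank 0's objects in place.
objs = [{"step": 1, "lr": 0.1}] if rank == 0 else [None]
dist.broadcast_object_list(objs, src=0, group=group)

# scatter_object_list: rank 0 provides one picklable object per rank in the
# group; every rank receives its share in `out`.
out = []
in_objs = [["for rank 0"], ["for rank 1"]] if rank == 0 else None
dist.scatter_object_list(out, in_objs, src=0, group=group)

print(rank, objs, out)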