diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 2506c3073941afa6a86b6739abc169e870e69a08..62b18298f11e05b16f0d09ab645592f307fc0b0b 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -2030,6 +2030,10 @@ def alltoall_single(in_tensor,
     return task
 
 
+def _get_group_rank(global_rank, group=None):
+    return global_rank if group is None else group.get_group_rank(global_rank)
+
+
 def send(tensor, dst=0, group=None, use_calc_stream=True):
     """
     Send a tensor to the receiver.
@@ -2062,11 +2066,10 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
     """
     if group is not None and not group.is_member():
         return
-
+    dst = _get_group_rank(dst, group)
     if in_dygraph_mode():
         group = _get_default_group() if group is None else group
-        group_dst_rank = group.get_group_rank(dst)
-        task = group.process_group.send(tensor, group_dst_rank)
+        task = group.process_group.send(tensor, dst)
         if use_calc_stream:
             task.wait()
             return None
@@ -2126,10 +2129,10 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
+    src = _get_group_rank(src, group)
     if in_dygraph_mode():
         group = _get_default_group() if group is None else group
-        group_src_rank = group.get_group_rank(src)
-        task = group.process_group.recv(tensor, group_src_rank)
+        task = group.process_group.recv(tensor, src)
         if use_calc_stream:
             task.wait()
             return None
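
For reference, a minimal sketch (not Paddle code) of what the new _get_group_rank helper does: with no group it passes the global rank through unchanged; with a group it delegates to group.get_group_rank to translate a global rank into that group's local rank. The _FakeGroup class below is a hypothetical stand-in assumed only to expose get_group_rank the way the diff uses it.

# Sketch of the global-rank -> group-rank mapping used above; _FakeGroup is
# a hypothetical stand-in, not Paddle's Group class.
class _FakeGroup:
    def __init__(self, ranks):
        # ranks: the global ranks that belong to this group, in group order.
        self.ranks = list(ranks)

    def get_group_rank(self, global_rank):
        # A global rank's position inside the group is its group-local rank.
        return self.ranks.index(global_rank)


def _get_group_rank(global_rank, group=None):
    # Same logic as the helper added in the diff.
    return global_rank if group is None else group.get_group_rank(global_rank)


if __name__ == "__main__":
    group = _FakeGroup(ranks=[2, 5, 7])
    print(_get_group_rank(5))         # 5: no group, global rank passes through
    print(_get_group_rank(5, group))  # 1: global rank 5 is the group's second member

With the helper, send and recv translate dst/src once up front and hand the already-translated rank to process_group.send/recv, instead of each call doing its own group.get_group_rank lookup inside the dygraph branch.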