From 74412dfefff946551f6f38cf41c58b37ba0710f8 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 18 Jul 2022 14:47:48 +0800 Subject: [PATCH] fix bug of old pp (#44361) --- python/paddle/distributed/collective.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 2506c30739..62b18298f1 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -2030,6 +2030,10 @@ def alltoall_single(in_tensor, return task +def _get_group_rank(global_rank, group=None): + return global_rank if group is None else group.get_group_rank(global_rank) + + def send(tensor, dst=0, group=None, use_calc_stream=True): """ Send a tensor to the receiver. @@ -2062,11 +2066,10 @@ def send(tensor, dst=0, group=None, use_calc_stream=True): """ if group is not None and not group.is_member(): return - + dst = _get_group_rank(dst, group) if in_dygraph_mode(): group = _get_default_group() if group is None else group - group_dst_rank = group.get_group_rank(dst) - task = group.process_group.send(tensor, group_dst_rank) + task = group.process_group.send(tensor, dst) if use_calc_stream: task.wait() return None @@ -2126,10 +2129,10 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return + src = _get_group_rank(src, group) if in_dygraph_mode(): group = _get_default_group() if group is None else group - group_src_rank = group.get_group_rank(src) - task = group.process_group.recv(tensor, group_src_rank) + task = group.process_group.recv(tensor, src) if use_calc_stream: task.wait() return None -- GitLab