diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 91136033761913810e59481aacbafad91491654b..ce5c1cfe9eb8537e6140c250fdfdbfe409f4d72b 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -327,7 +327,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
-                paddle.distributed.wait(d, use_calc_stream=True)
+                if _in_legacy_dygraph():
+                    paddle.distributed.wait(d, use_calc_stream=True)
                 tasks.append(
                     send_partial(d,
                                  dst=0,
@@ -336,7 +337,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  group=_hcg.send_prev_group,
                                  use_calc_stream=False))
         else:
-            paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
+            if _in_legacy_dygraph():
+                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             tasks.append(
                 send_partial(tensor_send_prev,
                              dst=0,
@@ -355,12 +357,6 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  rank_id=mp_rank,
                                  group=_hcg.recv_prev_group,
                                  use_calc_stream=True))
-                tasks.append(
-                    allgather_partial(d,
-                                      nranks=mp_degree,
-                                      rank_id=mp_rank,
-                                      group=mp_group,
-                                      use_calc_stream=True))
         else:
             tasks.append(
                 recv_partial(tensor_recv_prev,
@@ -369,17 +365,12 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              rank_id=mp_rank,
                              group=_hcg.recv_prev_group,
                              use_calc_stream=True))
-            tasks.append(
-                allgather_partial(tensor_recv_prev,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True))
 
     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
             for d in tensor_send_next:
-                paddle.distributed.wait(d, use_calc_stream=True)
+                if _in_legacy_dygraph():
+                    paddle.distributed.wait(d, use_calc_stream=True)
                 tasks.append(
                     send_partial(d,
                                  dst=1,
@@ -388,7 +379,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  group=_hcg.send_next_group,
                                  use_calc_stream=False))
         else:
-            paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
+            if _in_legacy_dygraph():
+                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             tasks.append(
                 send_partial(tensor_send_next,
                              dst=1,
@@ -407,12 +399,6 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  rank_id=mp_rank,
                                  group=_hcg.recv_next_group,
                                  use_calc_stream=True))
-                tasks.append(
-                    allgather_partial(d,
-                                      nranks=mp_degree,
-                                      rank_id=mp_rank,
-                                      group=mp_group,
-                                      use_calc_stream=True))
 
         else:
             tasks.append(
@@ -423,17 +409,40 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              group=_hcg.recv_next_group,
                              use_calc_stream=True))
 
-            tasks.append(
-                allgather_partial(tensor_recv_next,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True))
     if in_dygraph_mode():
-        # wait tasks in new dygraph mode with new comm library
+        # wait isend/irecv tasks in eager dygraph mode with new comm library
         for task in tasks:
-            if task is not None:
-                task.wait()
+            assert task is not None
+            task.wait()
+
+    tensors_for_all_gather = []
+    if tensor_recv_prev is not None:
+        if isinstance(tensor_recv_prev, tuple):
+            for d in tensor_recv_prev:
+                tensors_for_all_gather.append(d)
+        else:
+            tensors_for_all_gather.append(tensor_recv_prev)
+    if tensor_recv_next is not None:
+        if isinstance(tensor_recv_next, tuple):
+            for d in tensor_recv_next:
+                tensors_for_all_gather.append(d)
+        else:
+            tensors_for_all_gather.append(tensor_recv_next)
+
+    tasks = []
+    for tensor in tensors_for_all_gather:
+        tasks.append(
+            allgather_partial(tensor,
+                              nranks=mp_degree,
+                              rank_id=mp_rank,
+                              group=mp_group,
+                              use_calc_stream=True))
+
+    for task in tasks:
+        # wait partial all gather tasks
+        if task is not None:
+            task.wait()
+
     return tensor_recv_prev, tensor_recv_next
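
Note: the ordering this patch converges on is "finish every point-to-point isend/irecv first, then launch the partial all-gathers over all received tensors, then wait on those". The snippet below is a minimal, framework-free sketch of that ordering, included only for reference; FakeTask, irecv_partial and allgather_partial here are illustrative stand-ins, not PaddlePaddle APIs.

# Minimal sketch (not part of the patch) of the reordered communication
# flow in _p2p_helper: wait on all isend/irecv tasks first, then launch
# the partial all-gathers over the received tensors, then wait on those.
# FakeTask and the *_partial helpers are illustrative stubs, not
# PaddlePaddle APIs.


class FakeTask:

    def __init__(self, name):
        self.name = name

    def wait(self):
        print("completed", self.name)


def irecv_partial(tensor_name):
    # stand-in for an asynchronous partial receive
    return FakeTask("irecv_partial(%s)" % tensor_name)


def allgather_partial(tensor_name):
    # stand-in for an asynchronous partial all-gather
    return FakeTask("allgather_partial(%s)" % tensor_name)


def p2p_then_allgather(recv_names):
    # Stage 1: issue the point-to-point receives and wait for all of them.
    tasks = [irecv_partial(name) for name in recv_names]
    for task in tasks:
        assert task is not None
        task.wait()

    # Stage 2: only after the p2p traffic has finished, all-gather the
    # received partial tensors and wait on those tasks as well.
    tasks = [allgather_partial(name) for name in recv_names]
    for task in tasks:
        if task is not None:
            task.wait()


if __name__ == "__main__":
    p2p_then_allgather(["tensor_recv_prev", "tensor_recv_next"])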