diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 3b4094f047552dd5d822f4f4c576779f9a4e16f3..160c5f1511220aa8ec36f670d10f96f1b5e90cd9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -329,23 +329,21 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_prev:
                 if _in_legacy_dygraph():
                     paddle.distributed.wait(d, use_calc_stream=True)
-                tasks.append(
-                    send_partial(d,
-                                 dst=0,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.send_prev_group,
-                                 use_calc_stream=False))
-        else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
-            tasks.append(
-                send_partial(tensor_send_prev,
+                send_partial(d,
                              dst=0,
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.send_prev_group,
-                             use_calc_stream=False))
+                             use_calc_stream=False)
+        else:
+            if _in_legacy_dygraph():
+                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
+            send_partial(tensor_send_prev,
+                         dst=0,
+                         nranks=mp_degree,
+                         rank_id=mp_rank,
+                         group=_hcg.send_prev_group,
+                         use_calc_stream=False)
 
     if tensor_recv_prev is not None:
         if isinstance(tensor_recv_prev, tuple):
@@ -371,23 +369,21 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_next:
                 if _in_legacy_dygraph():
                     paddle.distributed.wait(d, use_calc_stream=True)
-                tasks.append(
-                    send_partial(d,
-                                 dst=1,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.send_next_group,
-                                 use_calc_stream=False))
-        else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
-            tasks.append(
-                send_partial(tensor_send_next,
+                send_partial(d,
                              dst=1,
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.send_next_group,
-                             use_calc_stream=False))
+                             use_calc_stream=False)
+        else:
+            if _in_legacy_dygraph():
+                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
+            send_partial(tensor_send_next,
+                         dst=1,
+                         nranks=mp_degree,
+                         rank_id=mp_rank,
+                         group=_hcg.send_next_group,
+                         use_calc_stream=False)
 
     if tensor_recv_next is not None:
         if isinstance(tensor_recv_next, tuple):
@@ -438,10 +434,11 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                   group=mp_group,
                                   use_calc_stream=True))
 
-    for task in tasks:
-        # wait partial all gather tasks
-        if task is not None:
-            task.wait()
+    if in_dygraph_mode():
+        for task in tasks:
+            # wait partial all gather tasks
+            if task is not None:
+                task.wait()
 
     return tensor_recv_prev, tensor_recv_next
 