diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index da20e312a1dc8764a2ee6873d74780b384e2a27e..fa4b937ba56b7240ee5b0b3f060022d5ae3b9c85 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 import numpy as np
@@ -27,10 +26,30 @@
 _hcg = None
 _use_cache = False
 _enable_partial_send_recv = True
+_xpu_comm_group_started = False
+
 _sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
 _sync_send = _sync_send.lower() in ['1', 'true']
 
 
+def _xpu_comm_group_start():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    assert not _xpu_comm_group_started
+    framework.core.ProcessGroupBKCL.group_start()
+    _xpu_comm_group_started = True
+
+
+def _xpu_comm_group_end():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    if _xpu_comm_group_started:
+        framework.core.ProcessGroupBKCL.group_end()
+        _xpu_comm_group_started = False
+
+
 def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
     global _hcg, _use_cache, _enable_partial_send_recv
     _hcg = hcg
@@ -357,6 +376,7 @@ def _p2p_helper(
     # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
     tasks = []
     # start to p2p communicate
+
     if _sync_send:
         # Some devices(NPU for example) do not support asynchronized send op, So the order is
         # recv_prev -> send_next -> recv_next -> send_prev
@@ -492,8 +512,8 @@
                     group=_hcg.send_prev_group,
                     use_calc_stream=False,
                 )
-
     else:
+        _xpu_comm_group_start()
         if tensor_send_prev is not None:
             if isinstance(tensor_send_prev, tuple):
                 for d in tensor_send_prev:
@@ -529,6 +549,7 @@
                         use_calc_stream=sync_recv,
                     )
                     if sync_recv:
+                        _xpu_comm_group_end()
                         allgather_partial(
                             d,
                             nranks=mp_degree,
@@ -549,6 +570,7 @@
                 )
 
                 if sync_recv:
+                    _xpu_comm_group_end()
                     allgather_partial(
                         tensor_recv_prev,
                         nranks=mp_degree,
@@ -595,6 +617,7 @@
                     )
 
                     if sync_recv:
+                        _xpu_comm_group_end()
                         allgather_partial(
                             d,
                             nranks=mp_degree,
@@ -615,6 +638,7 @@
                 use_calc_stream=sync_recv,
             )
             if sync_recv:
+                _xpu_comm_group_end()
                 allgather_partial(
                     tensor_recv_next,
                     nranks=mp_degree,
@@ -624,7 +648,7 @@
                 else:
                     tasks.append(task)
-
+    _xpu_comm_group_end()
     if not sync_recv:
        if framework.in_dygraph_mode():
             # wait irecv tasks in eager dygraph mode with new comm library
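
Below is a minimal, self-contained sketch of the batching pattern this diff introduces, for readers without an XPU build at hand. The DummyCommBackend class and the helper names comm_group_start/comm_group_end are hypothetical stand-ins for framework.core.ProcessGroupBKCL and the new _xpu_comm_group_start/_xpu_comm_group_end helpers; only the control flow mirrors the diff: the batch is opened exactly once at the top of the non-sync-send branch, closed before any allgather_partial that runs when sync_recv is true, and the trailing close is a no-op if the batch was already closed.

class DummyCommBackend:
    """Hypothetical stand-in for framework.core.ProcessGroupBKCL (illustration only)."""

    @staticmethod
    def group_start():
        print("comm batch opened")

    @staticmethod
    def group_end():
        print("comm batch flushed")


_comm_group_started = False  # mirrors the module-level _xpu_comm_group_started flag


def comm_group_start(backend=DummyCommBackend):
    # Same guard as _xpu_comm_group_start(): a batch must not be opened twice.
    global _comm_group_started
    assert not _comm_group_started
    backend.group_start()
    _comm_group_started = True


def comm_group_end(backend=DummyCommBackend):
    # Same guard as _xpu_comm_group_end(): only closes a batch that is open, so it is
    # safe to call both before a sync-recv allgather and once after the whole block.
    global _comm_group_started
    if _comm_group_started:
        backend.group_end()
        _comm_group_started = False


if __name__ == "__main__":
    sync_recv = True
    comm_group_start()
    # ... partial send/recv ops would be issued here and queued inside the batch ...
    if sync_recv:
        comm_group_end()  # flush the batch before running allgather on the calc stream
        # allgather_partial(...) would run here
    comm_group_end()  # trailing call mirrors the final _xpu_comm_group_end(); no-op here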