diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 7962e2dd4373e643a4510bc5c56bf53b3c02f88e..e2ca6f8d2a034c4f036bdfee35c87b2cfad5312e 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -165,17 +165,15 @@ def _is_valid_send_recv_partial(tensor, mp_degree): def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id): - dst_rank_in_group = dst if group is None else group.get_group_rank(dst) if _in_legacy_dygraph(): return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', dst_rank_in_group, 'num', - nranks, 'id', rank_id) + 'peer', dst, 'num', nranks, 'id', + rank_id) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.send_partial(tensor, dst_rank_in_group, - nranks, rank_id) + return group.process_group.send_partial(tensor, dst, nranks, rank_id) def send_partial(tensor, @@ -189,13 +187,12 @@ def send_partial(tensor, return ring_id = 0 if group is None else group.id - dst_rank = _hcg._get_p2p_next_rank( - ) if dst == 1 else _hcg._get_p2p_prev_rank() - if _is_valid_send_recv_partial(tensor, nranks): - return _partial_send_op(tensor, group, use_calc_stream, ring_id, - dst_rank, nranks, rank_id) + return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, + nranks, rank_id) else: + dst_rank = _hcg._get_p2p_next_rank( + ) if dst == 1 else _hcg._get_p2p_prev_rank() if _in_legacy_dygraph(): send_op = paddle.distributed.send elif in_dygraph_mode(): @@ -205,23 +202,22 @@ def send_partial(tensor, def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, rank_id): - src_rank_in_group = src if group is None else group.get_group_rank(src) if _in_legacy_dygraph(): assert use_calc_stream return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', src_rank_in_group, 'num', - nranks, 'id', rank_id, 'dtype', - tensor.dtype, 'out_shape', - tensor.shape) + 'peer', src, 'num', nranks, 'id', + rank_id, 'dtype', tensor.dtype, + 'out_shape', tensor.shape) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.recv_partial(tensor, src_rank_in_group, - nranks, rank_id) + task = group.process_group.recv_partial(tensor, src, nranks, rank_id) if use_calc_stream: task.wait() - return task + return None + else: + return task def recv_partial(tensor, @@ -235,13 +231,12 @@ def recv_partial(tensor, return ring_id = 0 if group is None else group.id - src_rank = _hcg._get_p2p_prev_rank( - ) if src == 0 else _hcg._get_p2p_next_rank() - if _is_valid_send_recv_partial(tensor, nranks): - return _partial_recv_op(tensor, group, use_calc_stream, ring_id, - src_rank, nranks, rank_id) + return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, + nranks, rank_id) else: + src_rank = _hcg._get_p2p_prev_rank( + ) if src == 0 else _hcg._get_p2p_next_rank() if _in_legacy_dygraph() or use_calc_stream: recv_op = paddle.distributed.recv elif in_dygraph_mode(): @@ -260,8 +255,13 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.all_gather_partial(tensor, tensor, nranks, + task = group.process_group.all_gather_partial(tensor, tensor, nranks, rank_id) + if use_calc_stream: + task.wait() + return None + else: + return task def allgather_partial(tensor, @@ -270,9 +270,9 @@ def allgather_partial(tensor, group=None, use_calc_stream=True): if not _is_valid_send_recv_partial(tensor, nranks): - return None + return tensor if group is not None and not group.is_member(): - return None + return ring_id = 0 if group is None else group.id return _partial_allgather_op(tensor, group, use_calc_stream, ring_id, @@ -335,8 +335,7 @@ def _p2p_helper(tensor_send_next, if tensor_send_prev is not None: if isinstance(tensor_send_prev, tuple): for d in tensor_send_prev: - if _in_legacy_dygraph(): - paddle.distributed.wait(d, use_calc_stream=True) + paddle.distributed.wait(d, use_calc_stream=True) send_partial(d, dst=0, nranks=mp_degree, @@ -344,8 +343,7 @@ def _p2p_helper(tensor_send_next, group=_hcg.send_prev_group, use_calc_stream=False) else: - if _in_legacy_dygraph(): - paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) send_partial(tensor_send_prev, dst=0, nranks=mp_degree, @@ -356,27 +354,40 @@ def _p2p_helper(tensor_send_next, if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - tasks.append( - recv_partial(d, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=sync_recv)) + task = recv_partial(d, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) else: - tasks.append( - recv_partial(tensor_recv_prev, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=sync_recv)) + task = recv_partial(tensor_recv_prev, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(tensor_recv_prev, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) if tensor_send_next is not None: if isinstance(tensor_send_next, tuple): for d in tensor_send_next: - if _in_legacy_dygraph(): - paddle.distributed.wait(d, use_calc_stream=True) + paddle.distributed.wait(d, use_calc_stream=True) send_partial(d, dst=1, nranks=mp_degree, @@ -384,8 +395,7 @@ def _p2p_helper(tensor_send_next, group=_hcg.send_next_group, use_calc_stream=False) else: - if _in_legacy_dygraph(): - paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) send_partial(tensor_send_next, dst=1, nranks=mp_degree, @@ -396,57 +406,64 @@ def _p2p_helper(tensor_send_next, if tensor_recv_next is not None: if isinstance(tensor_recv_next, tuple): for d in tensor_recv_next: - tasks.append( - recv_partial(d, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=sync_recv)) - - else: - tasks.append( - recv_partial(tensor_recv_next, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=sync_recv)) + task = recv_partial(d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) - if not sync_recv and in_dygraph_mode(): - # wait irecv tasks in eager dygraph mode with new comm library - for task in tasks: - assert task is not None - task.wait() - - tensors_for_all_gather = [] - if tensor_recv_prev is not None: - if isinstance(tensor_recv_prev, tuple): - for d in tensor_recv_prev: - tensors_for_all_gather.append(d) else: - tensors_for_all_gather.append(tensor_recv_prev) - if tensor_recv_next is not None: - if isinstance(tensor_recv_next, tuple): - for d in tensor_recv_next: - tensors_for_all_gather.append(d) - else: - tensors_for_all_gather.append(tensor_recv_next) + task = recv_partial(tensor_recv_next, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(tensor_recv_next, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) + + if not sync_recv: + if in_dygraph_mode(): + # wait irecv tasks in eager dygraph mode with new comm library + for task in tasks: + assert task is not None + task.wait() - tasks = [] - for tensor in tensors_for_all_gather: - tasks.append( + tensors_for_all_gather = [] + if tensor_recv_prev is not None: + if isinstance(tensor_recv_prev, tuple): + for d in tensor_recv_prev: + tensors_for_all_gather.append(d) + else: + tensors_for_all_gather.append(tensor_recv_prev) + if tensor_recv_next is not None: + if isinstance(tensor_recv_next, tuple): + for d in tensor_recv_next: + tensors_for_all_gather.append(d) + else: + tensors_for_all_gather.append(tensor_recv_next) + + for tensor in tensors_for_all_gather: allgather_partial(tensor, nranks=mp_degree, rank_id=mp_rank, group=mp_group, - use_calc_stream=True)) - - if in_dygraph_mode(): - for task in tasks: - # wait partial all gather tasks - if task is not None: - task.wait() + use_calc_stream=True) return tensor_recv_prev, tensor_recv_next