diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index b41fef05660b2f188dd2f8f929ffaf873383e984..a84c49d74aa06309423986d13ec5017e18ef2277 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -17,7 +17,11 @@ from ...utils.log_util import logger import numpy as np from paddle import _C_ops, _legacy_C_ops import paddle.fluid.core as core -from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode +from paddle.fluid.framework import ( + _in_legacy_dygraph, + _non_static_mode, + in_dygraph_mode, +) from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None @@ -30,12 +34,23 @@ def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True): _hcg = hcg _use_cache = use_cache _enable_partial_send_recv = enable_partial_send_recv - send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( + ( + send_next_group, + send_prev_group, + recv_next_group, + recv_prev_group, + ) = _hcg.get_p2p_groups() + + debug_str = ( + "P2pInfo: send_next_group: %s, send_prev_group: %s, " + "recv_next_group: %s, recv_prev_group: %s" + % ( + repr(send_next_group), + repr(send_prev_group), + repr(recv_next_group), + repr(recv_prev_group), + ) ) - - debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \ - "recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group), - repr(send_prev_group),repr(recv_next_group), repr(recv_prev_group)) logger.info(debug_str) @@ -150,9 +165,15 @@ class SendRecvMeta: self.send_dtype_message = paddle_2_number(tensor.dtype) elif isinstance(tensor, tuple): self.send_shape_message = tuple( - [d.shape for d in tensor if not d.stop_gradient]) + [d.shape for d in tensor if not d.stop_gradient] + ) self.send_dtype_message = tuple( - [paddle_2_number(d.dtype) for d in tensor]) + [ + paddle_2_number(d.dtype) + for d in tensor + if not d.stop_gradient + ] + ) _send_recv_meta = SendRecvMeta() @@ -166,84 +187,117 @@ def _is_valid_send_recv_partial(tensor, mp_degree): return mp_degree > 1 and tensor_numel % mp_degree == 0 -def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks, - rank_id): +def _partial_send_op( + tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id +): dst_rank_in_group = dst if group is None else group.get_group_rank(dst) if _in_legacy_dygraph(): - return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream', - use_calc_stream, 'ring_id', ring_id, - 'peer', dst_rank_in_group, 'num', - nranks, 'id', rank_id) + return _legacy_C_ops.partial_send( + tensor.detach(), + 'use_calc_stream', + use_calc_stream, + 'ring_id', + ring_id, + 'peer', + dst_rank_in_group, + 'num', + nranks, + 'id', + rank_id, + ) elif in_dygraph_mode(): - group = paddle.distributed.collective._get_default_group( - ) if group is None else group - comm_op = group.process_group.send_partial_on_calc_stream \ - if use_calc_stream else group.process_group.send_partial + group = ( + paddle.distributed.collective._get_default_group() + if group is None + else group + ) + comm_op = ( + group.process_group.send_partial_on_calc_stream + if use_calc_stream + else group.process_group.send_partial + ) return comm_op(tensor, dst_rank_in_group, nranks, rank_id) -def send_partial(tensor, - dst=0, - nranks=1, - rank_id=0, - group=None, - use_calc_stream=True): +def send_partial( + tensor, dst=0, nranks=1, rank_id=0, group=None, use_calc_stream=True +): # dst: local rank in group if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id - dst_rank = _hcg._get_p2p_next_rank( - ) if dst == 1 else _hcg._get_p2p_prev_rank() + dst_rank = ( + _hcg._get_p2p_next_rank() if dst == 1 else _hcg._get_p2p_prev_rank() + ) if _is_valid_send_recv_partial(tensor, nranks): - return _partial_send_op(tensor, group, use_calc_stream, ring_id, - dst_rank, nranks, rank_id) + return _partial_send_op( + tensor, group, use_calc_stream, ring_id, dst_rank, nranks, rank_id + ) else: if _in_legacy_dygraph(): - send_op = lambda x, dst, group: \ - paddle.distributed.send(x, dst, group, use_calc_stream) + send_op = lambda x, dst, group: paddle.distributed.send( + x, dst, group, use_calc_stream + ) elif in_dygraph_mode(): send_op = paddle.distributed.isend return send_op(tensor.detach(), dst=dst_rank, group=group) -def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, - rank_id): +def _partial_recv_op( + tensor, group, use_calc_stream, ring_id, src, nranks, rank_id +): src_rank_in_group = src if group is None else group.get_group_rank(src) if _in_legacy_dygraph(): assert use_calc_stream - return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', - use_calc_stream, 'ring_id', ring_id, - 'peer', src_rank_in_group, 'num', - nranks, 'id', rank_id, 'dtype', - tensor.dtype, 'out_shape', - tensor.shape) + return _legacy_C_ops.partial_recv( + tensor.detach(), + 'use_calc_stream', + use_calc_stream, + 'ring_id', + ring_id, + 'peer', + src_rank_in_group, + 'num', + nranks, + 'id', + rank_id, + 'dtype', + tensor.dtype, + 'out_shape', + tensor.shape, + ) elif in_dygraph_mode(): - group = paddle.distributed.collective._get_default_group( - ) if group is None else group - comm_op = group.process_group.recv_partial_on_calc_stream \ - if use_calc_stream else group.process_group.recv_partial + group = ( + paddle.distributed.collective._get_default_group() + if group is None + else group + ) + comm_op = ( + group.process_group.recv_partial_on_calc_stream + if use_calc_stream + else group.process_group.recv_partial + ) return comm_op(tensor, src_rank_in_group, nranks, rank_id) -def recv_partial(tensor, - src=0, - nranks=1, - rank_id=0, - group=None, - use_calc_stream=True): +def recv_partial( + tensor, src=0, nranks=1, rank_id=0, group=None, use_calc_stream=True +): # src: local rank in group if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id - src_rank = _hcg._get_p2p_prev_rank( - ) if src == 0 else _hcg._get_p2p_next_rank() + src_rank = ( + _hcg._get_p2p_prev_rank() if src == 0 else _hcg._get_p2p_next_rank() + ) if _is_valid_send_recv_partial(tensor, nranks): - return _partial_recv_op(tensor, group, use_calc_stream, ring_id, - src_rank, nranks, rank_id) + return _partial_recv_op( + tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id + ) else: if _in_legacy_dygraph() or use_calc_stream: recv_op = paddle.distributed.recv @@ -252,42 +306,52 @@ def recv_partial(tensor, return recv_op(tensor.detach(), src=src_rank, group=group) -def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, - rank_id): +def _partial_allgather_op( + tensor, group, use_calc_stream, ring_id, nranks, rank_id +): if _in_legacy_dygraph(): - return _legacy_C_ops.partial_allgather_(tensor.detach(), - 'use_calc_stream', - use_calc_stream, 'ring_id', - ring_id, 'nranks', nranks, - 'rank', rank_id) + return _legacy_C_ops.partial_allgather_( + tensor.detach(), + 'use_calc_stream', + use_calc_stream, + 'ring_id', + ring_id, + 'nranks', + nranks, + 'rank', + rank_id, + ) elif in_dygraph_mode(): - group = paddle.distributed.collective._get_default_group( - ) if group is None else group - comm_op = group.process_group.all_gather_partial_on_calc_stream \ - if use_calc_stream else group.process_group.all_gather_partial + group = ( + paddle.distributed.collective._get_default_group() + if group is None + else group + ) + comm_op = ( + group.process_group.all_gather_partial_on_calc_stream + if use_calc_stream + else group.process_group.all_gather_partial + ) return comm_op(tensor, tensor, nranks, rank_id) -def allgather_partial(tensor, - nranks=1, - rank_id=0, - group=None, - use_calc_stream=True): +def allgather_partial( + tensor, nranks=1, rank_id=0, group=None, use_calc_stream=True +): if not _is_valid_send_recv_partial(tensor, nranks): return tensor if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id - return _partial_allgather_op(tensor, group, use_calc_stream, ring_id, - nranks, rank_id) + return _partial_allgather_op( + tensor, group, use_calc_stream, ring_id, nranks, rank_id + ) -def _p2p_helper(tensor_send_next, - tensor_send_prev, - recv_prev, - recv_next, - sync_recv=True): +def _p2p_helper( + tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True +): global _hcg tensor_recv_prev = None @@ -310,15 +374,17 @@ def _p2p_helper(tensor_send_next, if isinstance(recv_shape_msg, tuple): tensor_recv_prev = [] for idx, shape in enumerate(recv_shape_msg): - tmp = paddle.empty(shape=shape, - dtype=number_2_dtype(recv_dtype_msg[idx])) + tmp = paddle.empty( + shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]) + ) tmp.stop_gradient = recv_stop_gradient[idx] tensor_recv_prev.append(tmp) tensor_recv_prev = tuple(tensor_recv_prev) else: tensor_recv_prev = paddle.empty( - shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)) + shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg) + ) tensor_recv_prev.stop_gradient = recv_stop_gradient if recv_next: @@ -326,12 +392,15 @@ def _p2p_helper(tensor_send_next, tensor_recv_next = [] for idx, shape in enumerate(send_shape_msg): tensor_recv_next.append( - paddle.empty(shape=shape, - dtype=number_2_dtype(send_dtype_msg[idx]))) + paddle.empty( + shape=shape, dtype=number_2_dtype(send_dtype_msg[idx]) + ) + ) tensor_recv_next = tuple(tensor_recv_next) else: tensor_recv_next = paddle.empty( - shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) + shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg) + ) # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops tasks = [] @@ -340,51 +409,63 @@ def _p2p_helper(tensor_send_next, if isinstance(tensor_send_prev, tuple): for d in tensor_send_prev: paddle.distributed.wait(d, use_calc_stream=True) - send_partial(d, - dst=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_prev_group, - use_calc_stream=False) + send_partial( + d, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False, + ) else: paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - send_partial(tensor_send_prev, - dst=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_prev_group, - use_calc_stream=False) + send_partial( + tensor_send_prev, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False, + ) if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - task = recv_partial(d, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=sync_recv) + task = recv_partial( + d, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv, + ) if sync_recv: - allgather_partial(d, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True, + ) else: tasks.append(task) else: - task = recv_partial(tensor_recv_prev, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=sync_recv) + task = recv_partial( + tensor_recv_prev, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv, + ) if sync_recv: - allgather_partial(tensor_recv_prev, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + allgather_partial( + tensor_recv_prev, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True, + ) else: tasks.append(task) @@ -392,52 +473,64 @@ def _p2p_helper(tensor_send_next, if isinstance(tensor_send_next, tuple): for d in tensor_send_next: paddle.distributed.wait(d, use_calc_stream=True) - send_partial(d, - dst=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_next_group, - use_calc_stream=False) + send_partial( + d, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False, + ) else: paddle.distributed.wait(tensor_send_next, use_calc_stream=True) - send_partial(tensor_send_next, - dst=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_next_group, - use_calc_stream=False) + send_partial( + tensor_send_next, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False, + ) if tensor_recv_next is not None: if isinstance(tensor_recv_next, tuple): for d in tensor_recv_next: - task = recv_partial(d, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=sync_recv) + task = recv_partial( + d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv, + ) if sync_recv: - allgather_partial(d, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True, + ) else: tasks.append(task) else: - task = recv_partial(tensor_recv_next, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=sync_recv) + task = recv_partial( + tensor_recv_next, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv, + ) if sync_recv: - allgather_partial(tensor_recv_next, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + allgather_partial( + tensor_recv_next, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True, + ) else: tasks.append(task) @@ -463,11 +556,13 @@ def _p2p_helper(tensor_send_next, tensors_for_all_gather.append(tensor_recv_next) for tensor in tensors_for_all_gather: - allgather_partial(tensor, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + allgather_partial( + tensor, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True, + ) return tensor_recv_prev, tensor_recv_next @@ -480,11 +575,13 @@ def recv_forward(pp_first_stage, sync_recv=True): _send_recv_meta.recv_meta(_hcg.recv_prev_group) _send_recv_meta.has_recv_meta = _use_cache - input_tensor, _ = _p2p_helper(tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - sync_recv=sync_recv) + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + sync_recv=sync_recv, + ) return input_tensor @@ -492,11 +589,13 @@ def recv_backward(pp_last_stage, sync_recv=True): if pp_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _p2p_helper(tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - sync_recv=sync_recv) + _, output_tensor_grad = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + sync_recv=sync_recv, + ) return output_tensor_grad @@ -507,28 +606,34 @@ def send_forward(output_tensor, pp_last_stage): _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) _send_recv_meta.has_send_meta = _use_cache - _p2p_helper(tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False) + _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + ) def send_backward(input_tensor_grad, pp_first_stage): if not pp_first_stage: - _p2p_helper(tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False) + _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False, + ) def send_forward_recv_backward(output_tensor, pp_last_stage): if pp_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _p2p_helper(tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=True) + _, output_tensor_grad = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + ) return output_tensor_grad @@ -536,16 +641,18 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage): if pp_first_stage: input_tensor = None else: - input_tensor, _ = _p2p_helper(tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False) + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False, + ) return input_tensor -def send_forward_backward_recv_forward_backward(output_tensor, - input_tensor_grad, recv_prev, - recv_next): +def send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, recv_prev, recv_next +): # always have to send dytpe info to downstream if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) @@ -559,7 +666,8 @@ def send_forward_backward_recv_forward_backward(output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, - sync_recv=False) + sync_recv=False, + ) return input_tensor, output_tensor_grad @@ -573,19 +681,23 @@ def send_forward_recv_forward(output_tensor, recv_prev): _send_recv_meta.recv_meta(_hcg.recv_prev_group) _send_recv_meta.has_recv_meta = _use_cache - input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - sync_recv=False) + input_tensor, _ = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=recv_prev, + recv_next=False, + sync_recv=False, + ) return input_tensor def send_backward_recv_backward(input_tensor_grad, recv_next): - _, output_tensor_grad = _p2p_helper(tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - sync_recv=False) + _, output_tensor_grad = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=recv_next, + sync_recv=False, + ) return output_tensor_grad