From 71cdf0094f270ddf178032cf3d38fd8e300e81a1 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Tue, 20 Jun 2023 16:51:21 +0800
Subject: [PATCH] solve conflict (#54747)

---
 .../fleet/meta_parallel/pipeline_parallel.py  |   5 +-
 .../pp_utils/p2p_communication.py             | 329 ++++++------------
 2 files changed, 116 insertions(+), 218 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 9598112a119..8fcc3d855e1 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -10,11 +10,10 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
+import os
 import time
 import warnings
 
-import os
-
 import paddle
 from paddle import framework
 
@@ -176,6 +175,7 @@ class PipelineParallel(MetaParallelBase):
         self._enable_timer = self._strategy.hybrid_configs[
             "pp_configs"
         ].enable_timer
+        self._profiling = self._strategy.hybrid_configs["pp_configs"].profiling
 
         self._records = []
         self._record_format = (
@@ -303,6 +303,7 @@ class PipelineParallel(MetaParallelBase):
         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.
+
             parameter_list = [
                 p for p in model.parameters() if not p.stop_gradient
             ]
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index eede472f792..dd422635a8c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -175,30 +175,63 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
     if not _enable_partial_send_recv:
         return False
     tensor_numel = np.prod(tensor.shape)
-    assert tensor_numel != 0, "can't send/recv zero element"
+    assert tensor_numel > 0, "can't send/recv zero element"
     return mp_degree > 1 and tensor_numel % mp_degree == 0
 
 
-def _partial_send_op(tensor, group, dst, nranks, rank_id):
+def _send_on_calc_stream(tensor, group, dst, nranks=1, rank_id=0):
     assert (
         group is not None
-    ), "Group should be an instance for _partial_send_op."
+    ), "Group should be an instance for _send_on_calc_stream."
     dst_rank_in_group = group.get_group_rank(dst)
-    if framework.in_dynamic_mode():
-        return group.process_group.send_partial(
+    if _is_valid_send_recv_partial(tensor, nranks):
+        return group.process_group.send_partial_on_calc_stream(
             tensor, dst_rank_in_group, nranks, rank_id
         )
+    else:
+        return group.process_group.send_on_calc_stream(
+            tensor, dst_rank_in_group
+        )
 
 
-def _partial_recv_op(tensor, group, src, nranks, rank_id):
+def _recv_on_calc_stream(tensor, group, src, nranks=1, rank_id=0):
     assert (
         group is not None
-    ), "Group should be an instance for _partial_recv_op."
+    ), "Group should be an instance for _recv_on_calc_stream."
     src_rank_in_group = group.get_group_rank(src)
-    if framework.in_dynamic_mode():
-        return group.process_group.recv_partial(
+    if _is_valid_send_recv_partial(tensor, nranks):
+        return group.process_group.recv_partial_on_calc_stream(
             tensor, src_rank_in_group, nranks, rank_id
         )
+    else:
+        return group.process_group.recv_on_calc_stream(
+            tensor, src_rank_in_group
+        )
+
+
+class P2PonCalcStream:
+    def __init__(self, op, tensor, peer, group, nranks=1, rank_id=0):
+        """
+        Args:
+            op (function): The function to be executed on the calc stream.
+            tensor (Tensor): The tensor to be sent or received.
+            peer (int): The peer rank.
+            group (Group): The process group to p2p.
+            nranks (int): The number of ranks in model parallel group.
+            rank_id (int): The rank id in the model parallel group.
+        """
+        if op not in [_send_on_calc_stream, _recv_on_calc_stream]:
+            raise RuntimeError(
+                "Invalid ``op`` function. Expected ``op`` "
+                "to be of type ``_send_on_calc_stream`` or "
+                "``_recv_on_calc_stream``."
+            )
+        self.op = op
+        self.tensor = tensor
+        self.peer = peer
+        self.group = group
+        self.nranks = nranks
+        self.rank_id = rank_id
 
 
 def _partial_allgather_op(
@@ -231,46 +264,39 @@ def allgather_partial(
     )
 
 
-def partial_batch_isend_irecv(p2p_op_list):
+def batch_send_recv_on_calc_stream(p2p_op_list):
     group = p2p_op_list[0].group
     if _warn_cur_rank_not_in_group(group):
         return
-
-    if framework.in_dynamic_mode():
-        group = _get_global_group() if group is None else group
-        backend = group.backend
-        tasks = []
-        with _with_batch_p2p_guard(backend):
-            for p2p_op in p2p_op_list:
-                op = p2p_op.op
-                tensor = p2p_op.tensor
-                peer = p2p_op.peer
-                comm_group = p2p_op.group
-                nranks = p2p_op.nranks
-                rank_id = p2p_op.rank_id
-                task = op(tensor, comm_group, peer, nranks, rank_id)
-                if task is not None:
-                    tasks.append(task)
-        return tasks
-    else:
-        raise RuntimeError("Don't support static graph mode currently.")
-
-
-class PartialP2POp:
-    def __init__(self, op, nranks, rank_id, tensor, peer, group):
-        if op not in [_partial_recv_op, _partial_send_op]:
-            raise RuntimeError(
-                "Invalid ``op`` function. Expected ``op`` "
-                "to be of type ``_partial_send_op`` or "
-                "``_partial_recv_op``."
+    group = _get_global_group() if group is None else group
+    backend = group.backend
+    with _with_batch_p2p_guard(backend):
+        for p2p_op in p2p_op_list:
+            op = p2p_op.op
+            tensor = p2p_op.tensor
+            peer = p2p_op.peer
+            comm_group = p2p_op.group
+            nranks = p2p_op.nranks
+            rank_id = p2p_op.rank_id
+            op(tensor, comm_group, peer, nranks, rank_id)
+
+
+def _process_p2p_tuple_or_tensor(
+    tensors, p2p_func, pp_rank, pp_group, mp_degree=1, mp_rank=0
+):
+    ops = []
+    if isinstance(tensors, tuple):
+        for tensor in tensors:
+            op = P2PonCalcStream(
+                p2p_func, tensor, pp_rank, pp_group, mp_degree, mp_rank
             )
-
-        self.op = op
-        self.nranks = nranks
-        self.rank_id = rank_id
-        self.tensor = tensor
-        self.peer = peer
-        self.group = group
+            ops.append(op)
+    else:
+        op = P2PonCalcStream(
+            p2p_func, tensors, pp_rank, pp_group, mp_degree, mp_rank
+        )
+        ops.append(op)
+    return ops
 
 
 def _p2p_helper(
@@ -326,189 +352,60 @@ def _p2p_helper(
     )
 
     ops = []
-    partial_ops = []
 
     pipe_group = _hcg.get_pipe_parallel_group()
 
+    # start to p2p communicate
     if tensor_send_prev is not None:
         src_rank = _hcg._get_p2p_prev_rank()
-        if isinstance(tensor_send_prev, tuple):
-            for d in tensor_send_prev:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_send_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.isend,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_send_prev, mp_degree):
-                op = PartialP2POp(
-                    _partial_send_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_send_prev,
-                    src_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.isend,
-                    tensor_send_prev,
-                    src_rank,
-                    pipe_group,
-                )
-                ops.append(op)
-
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_send_prev,
+                _send_on_calc_stream,
+                src_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )
     if tensor_recv_prev is not None:
         dst_rank = _hcg._get_p2p_prev_rank()
-        if isinstance(tensor_recv_prev, tuple):
-            for d in tensor_recv_prev:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_recv_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.irecv,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_recv_prev, mp_degree):
-                op = PartialP2POp(
-                    _partial_recv_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_recv_prev,
-                    dst_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.irecv,
-                    tensor_recv_prev,
-                    dst_rank,
-                    pipe_group,
-                )
-                ops.append(op)
-
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_recv_prev,
+                _recv_on_calc_stream,
+                dst_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )
     if tensor_send_next is not None:
         src_rank = _hcg._get_p2p_next_rank()
-        if isinstance(tensor_send_next, tuple):
-            for d in tensor_send_next:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_send_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.isend,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_send_next, mp_degree):
-                op = PartialP2POp(
-                    _partial_send_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_send_next,
-                    src_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.isend,
-                    tensor_send_next,
-                    src_rank,
-                    pipe_group,
-                )
-                ops.append(op)
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_send_next,
+                _send_on_calc_stream,
+                src_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )
     if tensor_recv_next is not None:
         dst_rank = _hcg._get_p2p_next_rank()
-        if isinstance(tensor_recv_next, tuple):
-            for d in tensor_recv_next:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_recv_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.irecv,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_recv_next, mp_degree):
-                op = PartialP2POp(
-                    _partial_recv_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_recv_next,
-                    dst_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.irecv,
-                    tensor_recv_next,
-                    dst_rank,
-                    pipe_group,
-                )
-                ops.append(op)
-
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_recv_next,
+                _recv_on_calc_stream,
+                dst_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )
     if len(ops) > 0:
-        reqs = paddle.distributed.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
-
-    if len(partial_ops) > 0:
-        reqs = partial_batch_isend_irecv(partial_ops)
-        for req in reqs:
-            req.wait()
-
-    # block cpu to wait the result
-    paddle.device.synchronize()
+        batch_send_recv_on_calc_stream(ops)
 
     tensors_for_all_gather = []
     if tensor_recv_prev is not None:
-- 
GitLab
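
Note for reviewers: the eligibility test that picks between the partial and the
whole-tensor path is compact enough to restate standalone. A minimal sketch of
that predicate, assuming only numpy (the real helper takes a Paddle tensor and
reads the module-level _enable_partial_send_recv flag; the name below is
illustrative):

    import numpy as np

    _enable_partial_send_recv = True  # stands in for the module-level switch

    def is_valid_send_recv_partial(shape, mp_degree):
        # Partial (sliced) p2p is only legal when the tensor is non-empty
        # and splits evenly across the mp_degree model-parallel ranks;
        # otherwise the helpers fall back to a whole-tensor send/recv.
        if not _enable_partial_send_recv:
            return False
        numel = int(np.prod(shape))
        assert numel > 0, "can't send/recv zero element"
        return mp_degree > 1 and numel % mp_degree == 0

    assert is_valid_send_recv_partial((4, 1024), 8)      # 4096 % 8 == 0
    assert not is_valid_send_recv_partial((4, 1024), 1)  # mp off: full path
    assert not is_valid_send_recv_partial((5, 3), 4)     # 15 % 4 != 0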
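
The shape of the refactor (build a flat list of op records, then replay them in
order inside one batch guard) can be sketched with stubbed-out communication
calls. Everything below is illustrative scaffolding, not Paddle API: the
print-based stubs stand in for group.process_group.send_on_calc_stream /
recv_on_calc_stream (or their _partial_ variants), and the record, normalizer,
and runner mirror P2PonCalcStream, _process_p2p_tuple_or_tensor, and
batch_send_recv_on_calc_stream respectively:

    def send_on_calc_stream(tensor, group, dst, nranks=1, rank_id=0):
        # Stub: the real helper chooses the partial or full send here.
        print(f"send {tensor!r} -> rank {dst} (nranks={nranks}, rank_id={rank_id})")

    def recv_on_calc_stream(tensor, group, src, nranks=1, rank_id=0):
        # Stub: the real helper chooses the partial or full recv here.
        print(f"recv {tensor!r} <- rank {src} (nranks={nranks}, rank_id={rank_id})")

    class P2POpRecord:
        # Mirrors P2PonCalcStream: a plain record of one pending send/recv.
        def __init__(self, op, tensor, peer, group, nranks=1, rank_id=0):
            if op not in (send_on_calc_stream, recv_on_calc_stream):
                raise RuntimeError("op must be one of the two helpers above")
            self.op = op
            self.tensor = tensor
            self.peer = peer
            self.group = group
            self.nranks = nranks
            self.rank_id = rank_id

    def process_tuple_or_tensor(tensors, p2p_func, peer, group, nranks=1, rank_id=0):
        # Mirrors _process_p2p_tuple_or_tensor: normalizing "a tensor or a
        # tuple of tensors" into records is what lets _p2p_helper collapse
        # its four copy-pasted branches into four ops.extend(...) calls.
        if not isinstance(tensors, tuple):
            tensors = (tensors,)
        return [P2POpRecord(p2p_func, t, peer, group, nranks, rank_id) for t in tensors]

    def batch_run(p2p_op_list):
        # Mirrors batch_send_recv_on_calc_stream: each op executes
        # synchronously on the calculation stream, so nothing is returned
        # and there are no requests to wait on.
        for p in p2p_op_list:
            p.op(p.tensor, p.group, p.peer, p.nranks, p.rank_id)

    ops = []
    ops.extend(process_tuple_or_tensor(("act0", "act1"), send_on_calc_stream, 1, None))
    ops.extend(process_tuple_or_tensor("grad", recv_on_calc_stream, 1, None))
    batch_run(ops)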
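
The payoff is visible in the last hunk: because every queued op now runs
synchronously on the calculation stream, the old completion bookkeeping
disappears wholesale. The separate partial_ops queue, the req.wait() loops
after paddle.distributed.batch_isend_irecv and partial_batch_isend_irecv, and
the trailing paddle.device.synchronize() that blocked the CPU are all removed;
stream ordering alone guarantees the received buffers are ready for the
allgather step that follows.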