diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 02a1b421526dfc88a39edf1b82c394f2c816187a..56429b748064daeac2780d5414513fffa9003b58 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -526,7 +526,7 @@ class PipelineParallelWithInterleave(PipelineParallel): self.set_virtual_pipeline_rank(0) self.input_tensors[0].append( - p2p.recv_forward(self.is_pipeline_first_stage())) + p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)) # run startup steps for micro_step in range(startup_steps): @@ -647,7 +647,8 @@ class PipelineParallelWithInterleave(PipelineParallel): if not forward_only: if all_startup_steps: self.output_tensor_grads[self.num_model_chunks - 1].append( - p2p.recv_backward(self.is_pipeline_last_stage())) + p2p.recv_backward(self.is_pipeline_last_stage(), + sync_recv=False)) for micro_step in range(steady_steps, num_steps): # cooldown loop diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 160c5f1511220aa8ec36f670d10f96f1b5e90cd9..7962e2dd4373e643a4510bc5c56bf53b3c02f88e 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -207,6 +207,7 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, rank_id): src_rank_in_group = src if group is None else group.get_group_rank(src) if _in_legacy_dygraph(): + assert use_calc_stream return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', src_rank_in_group, 'num', @@ -216,8 +217,11 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.recv_partial(tensor, src_rank_in_group, + task = group.process_group.recv_partial(tensor, src_rank_in_group, nranks, rank_id) + if use_calc_stream: + task.wait() + return task def recv_partial(tensor, @@ -238,7 +242,7 @@ def recv_partial(tensor, return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id) else: - if _in_legacy_dygraph(): + if _in_legacy_dygraph() or use_calc_stream: recv_op = paddle.distributed.recv elif in_dygraph_mode(): recv_op = paddle.distributed.irecv @@ -275,7 +279,11 @@ def allgather_partial(tensor, nranks, rank_id) -def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): +def _p2p_helper(tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next, + sync_recv=True): global _hcg tensor_recv_prev = None @@ -354,7 +362,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_prev_group, - use_calc_stream=True)) + use_calc_stream=sync_recv)) else: tasks.append( recv_partial(tensor_recv_prev, @@ -362,7 +370,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_prev_group, - use_calc_stream=True)) + use_calc_stream=sync_recv)) if tensor_send_next is not None: if isinstance(tensor_send_next, tuple): @@ -394,7 +402,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_next_group, - use_calc_stream=True)) + use_calc_stream=sync_recv)) else: tasks.append( @@ -403,10 +411,10 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_next_group, - use_calc_stream=True)) + use_calc_stream=sync_recv)) - if in_dygraph_mode(): - # wait isend/irecv tasks in eager dygraph mode with new comm library + if not sync_recv and in_dygraph_mode(): + # wait irecv tasks in eager dygraph mode with new comm library for task in tasks: assert task is not None task.wait() @@ -443,7 +451,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): return tensor_recv_prev, tensor_recv_next -def recv_forward(pp_first_stage): +def recv_forward(pp_first_stage, sync_recv=True): if pp_first_stage: input_tensor = None else: @@ -454,18 +462,20 @@ def recv_forward(pp_first_stage): input_tensor, _ = _p2p_helper(tensor_send_next=None, tensor_send_prev=None, recv_prev=True, - recv_next=False) + recv_next=False, + sync_recv=sync_recv) return input_tensor -def recv_backward(pp_last_stage): +def recv_backward(pp_last_stage, sync_recv=True): if pp_last_stage: output_tensor_grad = None else: _, output_tensor_grad = _p2p_helper(tensor_send_next=None, tensor_send_prev=None, recv_prev=False, - recv_next=True) + recv_next=True, + sync_recv=sync_recv) return output_tensor_grad @@ -527,7 +537,8 @@ def send_forward_backward_recv_forward_backward(output_tensor, tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, - recv_next=recv_next) + recv_next=recv_next, + sync_recv=False) return input_tensor, output_tensor_grad @@ -544,7 +555,8 @@ def send_forward_recv_forward(output_tensor, recv_prev): input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, - recv_next=False) + recv_next=False, + sync_recv=False) return input_tensor @@ -553,5 +565,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next): _, output_tensor_grad = _p2p_helper(tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, - recv_next=recv_next) + recv_next=recv_next, + sync_recv=False) return output_tensor_grad