From 5a9214d82955287d8561baec7d1c5192ba717ee6 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Thu, 31 Aug 2023 11:31:56 +0800 Subject: [PATCH] [Distributed]Fix cache p2p in pp (#56796) * add usecache * add p2p cache fix * add cache --- .../fleet/meta_parallel/pipeline_parallel.py | 67 +++- .../pp_utils/p2p_communication.py | 379 ++++++++++-------- 2 files changed, 248 insertions(+), 198 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index cfabb974f95..4836472e665 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -199,11 +199,13 @@ class PipelineParallel(MetaParallelBase): p2p.initialize_p2p_groups( hcg, - self._using_cache, self._enable_partial_send_recv, self._enable_timer, ) + # construct pipeline meta info + self._p2p_helper = p2p.P2pHelper(self._using_cache) + self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 @@ -349,10 +351,14 @@ class PipelineParallel(MetaParallelBase): micro_dataset = self._wrap_data(data) for step_id in range(startup_steps): - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) output_tensor = self._forward_step(input_tensor, micro_dataset) - p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) @@ -361,14 +367,16 @@ class PipelineParallel(MetaParallelBase): self._release_output(output_tensor) if steady_steps > 0: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) for i in range(steady_steps): last_iter = i == (steady_steps - 1) output_tensor = self._forward_step(input_tensor, micro_dataset) - output_tensor_grad = p2p.send_forward_recv_backward( + output_tensor_grad = self._p2p_helper.send_forward_recv_backward( output_tensor, self.is_pipeline_last_stage() ) @@ -388,11 +396,11 @@ class PipelineParallel(MetaParallelBase): if last_iter: input_tensor = None - p2p.send_backward( + self._p2p_helper.send_backward( input_tensor_grad, self.is_pipeline_first_stage() ) else: - input_tensor = p2p.send_backward_recv_forward( + input_tensor = self._p2p_helper.send_backward_recv_forward( input_tensor_grad, self.is_pipeline_first_stage() ) @@ -400,14 +408,16 @@ class PipelineParallel(MetaParallelBase): input_tensor = input_buffers.pop(0) output_tensor = output_buffers.pop(0) - output_tensor_grad = p2p.recv_backward( + output_tensor_grad = self._p2p_helper.recv_backward( self.is_pipeline_last_stage() ) input_tensor_grad = self._backward_step( input_tensor, output_tensor, output_tensor_grad ) - p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage()) + self._p2p_helper.send_backward( + input_tensor_grad, self.is_pipeline_first_stage() + ) if self._comm_overlap: assert ( @@ -513,28 +523,38 @@ class PipelineParallel(MetaParallelBase): micro_dataset = self._wrap_data(data) for step_id in range(startup_steps): - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) output_tensor = self._forward_step(input_tensor, micro_dataset) - p2p.send_forward(output_tensor, 
self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if steady_steps > 0: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) for i in range(steady_steps): last_iter = i == (steady_steps - 1) output_tensor = self._forward_step(input_tensor, micro_dataset) - p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) + self._p2p_helper.send_forward( + output_tensor, self.is_pipeline_last_stage() + ) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if not last_iter: - input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) + input_tensor = self._p2p_helper.recv_forward( + self.is_pipeline_first_stage() + ) if self._compute_loss: self.train_loss = self._broadcast_final_loss() @@ -859,6 +879,11 @@ class PipelineParallelWithInterleave(PipelineParallel): not forward_only ), "compute_loss can only be set to False when forward_only is set to True" + # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled + assert ( + self._using_cache + ), "cache should be enabled for pipeline with interleave" + # init some attributes for this batch run self.scaler = scaler self.total_loss = None @@ -904,7 +929,9 @@ class PipelineParallelWithInterleave(PipelineParallel): self.set_virtual_pipeline_rank(0) self.input_tensors[0].append( - p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False) + self._p2p_helper.recv_forward( + self.is_pipeline_first_stage(), sync_recv=False + ) ) # run startup steps @@ -942,7 +969,7 @@ class PipelineParallelWithInterleave(PipelineParallel): ( input_tensor, output_tensor_grad, - ) = p2p.send_forward_backward_recv_forward_backward( + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, @@ -952,7 +979,7 @@ class PipelineParallelWithInterleave(PipelineParallel): output_tensor_grad ) else: - input_tensor = p2p.send_forward_recv_forward( + input_tensor = self._p2p_helper.send_forward_recv_forward( output_tensor, recv_prev=recv_prev ) self.input_tensors[next_virtual_pp_rank].append(input_tensor) @@ -1033,7 +1060,7 @@ class PipelineParallelWithInterleave(PipelineParallel): ( input_tensor, output_tensor_grad, - ) = p2p.send_forward_backward_recv_forward_backward( + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, @@ -1057,7 +1084,7 @@ class PipelineParallelWithInterleave(PipelineParallel): if not forward_only: if all_startup_steps: self.output_tensor_grads[self.num_model_chunks - 1].append( - p2p.recv_backward( + self._p2p_helper.recv_backward( self.is_pipeline_last_stage(), sync_recv=False ) ) @@ -1080,7 +1107,7 @@ class PipelineParallelWithInterleave(PipelineParallel): recv_next = False self.output_tensor_grads[next_backward_virtual_pp_rank].append( - p2p.send_backward_recv_backward( + self._p2p_helper.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next ) ) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 304b226ee1f..9f8a032d588 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ 
-31,17 +31,16 @@ from ...utils import timer_helper as timer from .utils import number_2_dtype, paddle_2_number _hcg = None -_use_cache = False +# _use_cache = False _enable_partial_send_recv = True _timers = None def initialize_p2p_groups( - hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False + hcg, enable_partial_send_recv=True, enable_timer=False ): - global _hcg, _use_cache, _enable_partial_send_recv, _timers + global _hcg, _enable_partial_send_recv, _timers _hcg = hcg - _use_cache = use_cache _enable_partial_send_recv = enable_partial_send_recv if enable_timer: _timers = timer.get_timers() @@ -170,8 +169,14 @@ class SendRecvMeta: ] ) - -_send_recv_meta = SendRecvMeta() + def __repr__(self): + return "send_shape_message: {}, send_dtype_message: {}, recv_shape_message: {}, recv_dtype_message: {}, recv_stop_gradient: {}".format( + self.send_shape_message, + self.send_dtype_message, + self.recv_shape_message, + self.recv_dtype_message, + self.recv_stop_gradient, + ) def _is_valid_send_recv_partial(tensor, mp_degree): @@ -303,7 +308,12 @@ def _process_p2p_tuple_or_tensor( def _p2p_helper( - tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True + tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next, + sync_recv=True, + send_recv_meta=None, ): global _hcg @@ -311,12 +321,13 @@ def _p2p_helper( tensor_recv_next = None # send / recv message - recv_shape_msg = _send_recv_meta.recv_shape_message - recv_dtype_msg = _send_recv_meta.recv_dtype_message - recv_stop_gradient = _send_recv_meta.recv_stop_gradient + assert send_recv_meta is not None, "send_recv_meta should not be None" + recv_shape_msg = send_recv_meta.recv_shape_message + recv_dtype_msg = send_recv_meta.recv_dtype_message + recv_stop_gradient = send_recv_meta.recv_stop_gradient - send_shape_msg = _send_recv_meta.send_shape_message - send_dtype_msg = _send_recv_meta.send_dtype_message + send_shape_msg = send_recv_meta.send_shape_message + send_dtype_msg = send_recv_meta.send_dtype_message # model parallel message mp_group = _hcg.get_model_parallel_group() @@ -441,183 +452,195 @@ def _p2p_helper( return tensor_recv_prev, tensor_recv_next -def recv_forward(pp_first_stage, sync_recv=True): - global _timers - if _timers is not None: - _timers("recv_forward").start() - if pp_first_stage: - input_tensor = None - else: - if not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - - input_tensor, _ = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - sync_recv=sync_recv, - ) - if _timers is not None: - _timers("recv_forward").stop() - return input_tensor - +class P2pHelper: + def __init__(self, use_cache=True): + self._send_recv_meta = SendRecvMeta() + self._use_cache = use_cache -def recv_backward(pp_last_stage, sync_recv=True): - global _timers - if _timers is not None: - _timers("recv_backward").start() - if pp_last_stage: - output_tensor_grad = None - else: - _, output_tensor_grad = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - sync_recv=sync_recv, - ) - if _timers is not None: - _timers("recv_backward").stop() - return output_tensor_grad - - -def send_forward(output_tensor, pp_last_stage): - global _timers - if _timers is not None: - _timers("send_forward").start() - if not pp_last_stage: - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta( 
+ def _send_meta(self, output_tensor): + if not self._send_recv_meta.has_send_meta: + self._send_recv_meta.set_send_message(output_tensor) + self._send_recv_meta.send_meta( output_tensor, _hcg.get_pipe_parallel_group() ) - _send_recv_meta.has_send_meta = _use_cache - - _p2p_helper( + self._send_recv_meta.has_send_meta = self._use_cache + + def _recv_meta(self): + if not self._send_recv_meta.has_recv_meta: + self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) + self._send_recv_meta.has_recv_meta = self._use_cache + + def recv_forward(self, pp_first_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_forward").start() + if pp_first_stage: + input_tensor = None + else: + self._recv_meta() + + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + sync_recv=sync_recv, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("recv_forward").stop() + return input_tensor + + def recv_backward(self, pp_last_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_backward").start() + if pp_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + sync_recv=sync_recv, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("recv_backward").stop() + return output_tensor_grad + + def send_forward(self, output_tensor, pp_last_stage): + global _timers + if _timers is not None: + _timers("send_forward").start() + if not pp_last_stage: + self._send_meta(output_tensor) + + _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_forward").stop() + + def send_backward(self, input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward").start() + if not pp_first_stage: + _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_backward").stop() + + def send_forward_recv_backward(self, output_tensor, pp_last_stage): + global _timers + if _timers is not None: + _timers("send_forward_recv_backward").start() + if pp_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_forward_recv_backward").stop() + return output_tensor_grad + + def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward_recv_forward").start() + if pp_first_stage: + input_tensor = None + else: + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False, + send_recv_meta=self._send_recv_meta, + ) + if _timers is not None: + _timers("send_backward_recv_forward").stop() + return input_tensor + + def send_forward_backward_recv_forward_backward( + self, output_tensor, input_tensor_grad, recv_prev, recv_next + ): + # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").start() + + 
self._send_meta(output_tensor) + if recv_prev: + self._recv_meta() + + input_tensor, output_tensor_grad = _p2p_helper( tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, - ) - if _timers is not None: - _timers("send_forward").stop() - - -def send_backward(input_tensor_grad, pp_first_stage): - global _timers - if _timers is not None: - _timers("send_backward").start() - if not pp_first_stage: - _p2p_helper( - tensor_send_next=None, tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False, + recv_prev=recv_prev, + recv_next=recv_next, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_backward").stop() + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").stop() + return input_tensor, output_tensor_grad + def send_forward_recv_forward(self, output_tensor, recv_prev): + # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_recv_forward").start() -def send_forward_recv_backward(output_tensor, pp_last_stage): - global _timers - if _timers is not None: - _timers("send_forward_recv_backward").start() - if pp_last_stage: - output_tensor_grad = None - else: - _, output_tensor_grad = _p2p_helper( + self._send_meta(output_tensor) + if recv_prev: + self._recv_meta() + + input_tensor, _ = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, - recv_prev=False, - recv_next=True, + recv_prev=recv_prev, + recv_next=False, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_forward_recv_backward").stop() - return output_tensor_grad - - -def send_backward_recv_forward(input_tensor_grad, pp_first_stage): - global _timers - if _timers is not None: - _timers("send_backward_recv_forward").start() - if pp_first_stage: - input_tensor = None - else: - input_tensor, _ = _p2p_helper( + if _timers is not None: + _timers("send_forward_recv_forward").stop() + return input_tensor + + def send_backward_recv_backward(self, input_tensor_grad, recv_next): + global _timers + if _timers is not None: + _timers("send_backward_recv_backward").start() + _, output_tensor_grad = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False, + recv_prev=False, + recv_next=recv_next, + sync_recv=False, + send_recv_meta=self._send_recv_meta, ) - if _timers is not None: - _timers("send_backward_recv_forward").stop() - return input_tensor - - -def send_forward_backward_recv_forward_backward( - output_tensor, input_tensor_grad, recv_prev, recv_next -): - # always have to send dytpe info to downstream - global _timers - if _timers is not None: - _timers("send_forward_backward_recv_forward_backward").start() - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) - _send_recv_meta.has_send_meta = _use_cache - if recv_prev and not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - input_tensor, output_tensor_grad = _p2p_helper( - tensor_send_next=output_tensor, - tensor_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - sync_recv=False, - ) - if _timers is not None: - _timers("send_forward_backward_recv_forward_backward").stop() - return input_tensor, output_tensor_grad - - -def send_forward_recv_forward(output_tensor, 
recv_prev): - # always have to send dytpe info to downstream - global _timers - if _timers is not None: - _timers("send_forward_recv_forward").start() - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) - _send_recv_meta.has_send_meta = _use_cache - if recv_prev and not _send_recv_meta.has_recv_meta: - _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) - _send_recv_meta.has_recv_meta = _use_cache - - input_tensor, _ = _p2p_helper( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - sync_recv=False, - ) - if _timers is not None: - _timers("send_forward_recv_forward").stop() - return input_tensor - - -def send_backward_recv_backward(input_tensor_grad, recv_next): - global _timers - if _timers is not None: - _timers("send_backward_recv_backward").start() - _, output_tensor_grad = _p2p_helper( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - sync_recv=False, - ) - if _timers is not None: - _timers("send_backward_recv_backward").stop() - return output_tensor_grad + if _timers is not None: + _timers("send_backward_recv_backward").stop() + return output_tensor_grad + + def __repr__(self): + debug_str = f"using cache: {self._use_cache} \n" + debug_str += repr(self._send_recv_meta) + return debug_str -- GitLab
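
The substance of this patch is that the shape/dtype handshake cache moves out of module-level globals (_send_recv_meta, _use_cache) and into a per-instance P2pHelper, so each pipeline scheduler owns its own send/recv metadata, and the interleaved schedule can now assert that the cache is enabled for its own helper. What follows is a minimal, Paddle-free sketch of that caching pattern only; the class and function names are illustrative stand-ins and are not part of the patched API.

    # Illustrative sketch of the per-instance meta caching that P2pHelper
    # implements in this patch; none of these names exist in Paddle.
    class MetaCacheSketch:
        def __init__(self, use_cache=True):
            self._use_cache = use_cache
            self._has_send_meta = False  # plays the role of SendRecvMeta.has_send_meta

        def send_meta_if_needed(self, shape, dtype, send_fn):
            # Perform the shape/dtype handshake only on the first send when
            # caching is enabled; with use_cache=False it runs on every send,
            # mirroring P2pHelper._send_meta.
            if not self._has_send_meta:
                send_fn(shape, dtype)
                self._has_send_meta = self._use_cache


    if __name__ == "__main__":
        sent = []
        helper_a = MetaCacheSketch(use_cache=True)
        helper_b = MetaCacheSketch(use_cache=True)

        for _ in range(3):
            helper_a.send_meta_if_needed((8, 1024), "float16",
                                         lambda s, d: sent.append(("a", s, d)))
            helper_b.send_meta_if_needed((8, 2048), "float16",
                                         lambda s, d: sent.append(("b", s, d)))

        # Each helper performs the metadata handshake exactly once and
        # independently of the other, which is the behaviour a single set of
        # module-level globals could not give two pipeline objects in one
        # process.
        print(sent)

Because the cache state is now per object rather than per module, the new assertion in PipelineParallelWithInterleave (self._using_cache must be true) constrains only that scheduler's own P2pHelper instance.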