Unverified commit 5a9214d8, authored by ShenLiang, committed by GitHub

[Distributed]Fix cache p2p in pp (#56796)

* add usecache

* add p2p cache fix

* add cache
Parent: dfcfc8b7
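
In short, the change retires the module-level cache state in the p2p communication utility (the `_use_cache` flag and the shared `SendRecvMeta` instance) and gives every pipeline runner its own helper, constructed as `self._p2p_helper = p2p.P2pHelper(self._using_cache)`. The standalone sketch below only illustrates the per-instance caching pattern the diff adopts; `DemoMeta` and `DemoP2pHelper` are hypothetical stand-ins, not Paddle APIs.

    # Hypothetical stand-ins (DemoMeta, DemoP2pHelper) illustrating the pattern:
    # the shape/dtype handshake runs once per helper instance when use_cache=True,
    # and on every call when use_cache=False, matching has_send_meta = use_cache.
    class DemoMeta:
        def __init__(self):
            self.has_send_meta = False

    class DemoP2pHelper:
        def __init__(self, use_cache=True):
            self._meta = DemoMeta()      # cache lives on the instance, not in a module global
            self._use_cache = use_cache
            self.handshakes = 0

        def send_forward(self, tensor):
            if not self._meta.has_send_meta:
                self.handshakes += 1                         # send shape/dtype meta
                self._meta.has_send_meta = self._use_cache   # only sticks if caching is on
            # ... the actual tensor send would happen here ...

    cached = DemoP2pHelper(use_cache=True)
    uncached = DemoP2pHelper(use_cache=False)
    for t in ["t0", "t1", "t2"]:
        cached.send_forward(t)
        uncached.send_forward(t)
    print(cached.handshakes)    # 1: meta sent once, then cached
    print(uncached.handshakes)  # 3: meta re-sent for every micro-batch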
@@ -199,11 +199,13 @@ class PipelineParallel(MetaParallelBase):
         p2p.initialize_p2p_groups(
             hcg,
-            self._using_cache,
             self._enable_partial_send_recv,
             self._enable_timer,
         )

+        # construct pipeline meta info
+        self._p2p_helper = p2p.P2pHelper(self._using_cache)
+
         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0

@@ -349,10 +351,14 @@ class PipelineParallel(MetaParallelBase):
         micro_dataset = self._wrap_data(data)

         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )

             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )

             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

@@ -361,14 +367,16 @@ class PipelineParallel(MetaParallelBase):
                 self._release_output(output_tensor)

         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )

         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

             output_tensor = self._forward_step(input_tensor, micro_dataset)

-            output_tensor_grad = p2p.send_forward_recv_backward(
+            output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
                 output_tensor, self.is_pipeline_last_stage()
             )

@@ -388,11 +396,11 @@ class PipelineParallel(MetaParallelBase):
                 if last_iter:
                     input_tensor = None
-                    p2p.send_backward(
+                    self._p2p_helper.send_backward(
                         input_tensor_grad, self.is_pipeline_first_stage()
                     )
                 else:
-                    input_tensor = p2p.send_backward_recv_forward(
+                    input_tensor = self._p2p_helper.send_backward_recv_forward(
                         input_tensor_grad, self.is_pipeline_first_stage()
                     )

@@ -400,14 +408,16 @@ class PipelineParallel(MetaParallelBase):
                 input_tensor = input_buffers.pop(0)
                 output_tensor = output_buffers.pop(0)

-                output_tensor_grad = p2p.recv_backward(
+                output_tensor_grad = self._p2p_helper.recv_backward(
                     self.is_pipeline_last_stage()
                 )

                 input_tensor_grad = self._backward_step(
                     input_tensor, output_tensor, output_tensor_grad
                 )
-                p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
+                self._p2p_helper.send_backward(
+                    input_tensor_grad, self.is_pipeline_first_stage()
+                )

             if self._comm_overlap:
                 assert (

@@ -513,28 +523,38 @@ class PipelineParallel(MetaParallelBase):
         micro_dataset = self._wrap_data(data)

         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )

             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )

             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )

         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )

             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

             if not last_iter:
-                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+                input_tensor = self._p2p_helper.recv_forward(
+                    self.is_pipeline_first_stage()
+                )

         if self._compute_loss:
             self.train_loss = self._broadcast_final_loss()

@@ -859,6 +879,11 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 not forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"

+        # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
+
         # init some attributes for this batch run
         self.scaler = scaler
         self.total_loss = None

@@ -904,7 +929,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.set_virtual_pipeline_rank(0)
         self.input_tensors[0].append(
-            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
+            self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage(), sync_recv=False
+            )
         )

         # run startup steps

@@ -942,7 +969,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 (
                     input_tensor,
                     output_tensor_grad,
-                ) = p2p.send_forward_backward_recv_forward_backward(
+                ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
                     output_tensor,
                     input_tensor_grad,
                     recv_prev=recv_prev,

@@ -952,7 +979,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     output_tensor_grad
                 )
             else:
-                input_tensor = p2p.send_forward_recv_forward(
+                input_tensor = self._p2p_helper.send_forward_recv_forward(
                     output_tensor, recv_prev=recv_prev
                 )
             self.input_tensors[next_virtual_pp_rank].append(input_tensor)

@@ -1033,7 +1060,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
             (
                 input_tensor,
                 output_tensor_grad,
-            ) = p2p.send_forward_backward_recv_forward_backward(
+            ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
                 output_tensor,
                 input_tensor_grad,
                 recv_prev=recv_prev,

@@ -1057,7 +1084,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if not forward_only:
             if all_startup_steps:
                 self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(
+                    self._p2p_helper.recv_backward(
                         self.is_pipeline_last_stage(), sync_recv=False
                     )
                 )

@@ -1080,7 +1107,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     recv_next = False
                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
-                    p2p.send_backward_recv_backward(
+                    self._p2p_helper.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

...
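On the caller side (hunks above), the interleaved schedule now hard-requires the cache: its batch entry point asserts `self._using_cache`, because the ring_exchange-based interleave reuses the cached send/recv meta. A minimal standalone sketch of that kind of configuration guard follows; the function name is made up, only the assert-on-`using_cache` idea comes from the diff. The remaining hunks below are from the p2p communication module itself (imported as `p2p` above), where the helper class is defined.

    # Standalone illustration (hypothetical function name): the same guard the
    # interleaved schedule now applies before running a batch.
    def check_interleave_config(using_cache: bool, interleave: bool) -> None:
        if interleave:
            # mirrors the new assertion: cache should be enabled for pipeline with interleave
            assert using_cache, "cache should be enabled for pipeline with interleave"

    check_interleave_config(using_cache=True, interleave=True)    # passes
    # check_interleave_config(using_cache=False, interleave=True) # would raise AssertionError
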
@@ -31,17 +31,16 @@ from ...utils import timer_helper as timer
 from .utils import number_2_dtype, paddle_2_number

 _hcg = None
-_use_cache = False
+# _use_cache = False
 _enable_partial_send_recv = True
 _timers = None


 def initialize_p2p_groups(
-    hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
+    hcg, enable_partial_send_recv=True, enable_timer=False
 ):
-    global _hcg, _use_cache, _enable_partial_send_recv, _timers
+    global _hcg, _enable_partial_send_recv, _timers
     _hcg = hcg
-    _use_cache = use_cache
     _enable_partial_send_recv = enable_partial_send_recv
     if enable_timer:
         _timers = timer.get_timers()

@@ -170,8 +169,14 @@ class SendRecvMeta:
             ]
         )

-
-_send_recv_meta = SendRecvMeta()
+    def __repr__(self):
+        return "send_shape_message: {}, send_dtype_message: {}, recv_shape_message: {}, recv_dtype_message: {}, recv_stop_gradient: {}".format(
+            self.send_shape_message,
+            self.send_dtype_message,
+            self.recv_shape_message,
+            self.recv_dtype_message,
+            self.recv_stop_gradient,
+        )


 def _is_valid_send_recv_partial(tensor, mp_degree):

@@ -303,7 +308,12 @@ def _process_p2p_tuple_or_tensor(


 def _p2p_helper(
-    tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
+    tensor_send_next,
+    tensor_send_prev,
+    recv_prev,
+    recv_next,
+    sync_recv=True,
+    send_recv_meta=None,
 ):
     global _hcg

@@ -311,12 +321,13 @@ def _p2p_helper(
     tensor_recv_next = None

     # send / recv message
-    recv_shape_msg = _send_recv_meta.recv_shape_message
-    recv_dtype_msg = _send_recv_meta.recv_dtype_message
-    recv_stop_gradient = _send_recv_meta.recv_stop_gradient
+    assert send_recv_meta is not None, "send_recv_meta should not be None"
+    recv_shape_msg = send_recv_meta.recv_shape_message
+    recv_dtype_msg = send_recv_meta.recv_dtype_message
+    recv_stop_gradient = send_recv_meta.recv_stop_gradient

-    send_shape_msg = _send_recv_meta.send_shape_message
-    send_dtype_msg = _send_recv_meta.send_dtype_message
+    send_shape_msg = send_recv_meta.send_shape_message
+    send_dtype_msg = send_recv_meta.send_dtype_message

     # model parallel message
     mp_group = _hcg.get_model_parallel_group()

@@ -441,183 +452,195 @@ def _p2p_helper(
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(pp_first_stage, sync_recv=True):
-    global _timers
-    if _timers is not None:
-        _timers("recv_forward").start()
-    if pp_first_stage:
-        input_tensor = None
-    else:
-        if not _send_recv_meta.has_recv_meta:
-            _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-            _send_recv_meta.has_recv_meta = _use_cache
-
-        input_tensor, _ = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_prev=True,
-            recv_next=False,
-            sync_recv=sync_recv,
-        )
-    if _timers is not None:
-        _timers("recv_forward").stop()
-    return input_tensor
-
-
-def recv_backward(pp_last_stage, sync_recv=True):
-    global _timers
-    if _timers is not None:
-        _timers("recv_backward").start()
-    if pp_last_stage:
-        output_tensor_grad = None
-    else:
-        _, output_tensor_grad = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=True,
-            sync_recv=sync_recv,
-        )
-    if _timers is not None:
-        _timers("recv_backward").stop()
-    return output_tensor_grad
-
-
-def send_forward(output_tensor, pp_last_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_forward").start()
-    if not pp_last_stage:
-        if not _send_recv_meta.has_send_meta:
-            _send_recv_meta.set_send_message(output_tensor)
-            _send_recv_meta.send_meta(
-                output_tensor, _hcg.get_pipe_parallel_group()
-            )
-            _send_recv_meta.has_send_meta = _use_cache
-
-        _p2p_helper(
-            tensor_send_next=output_tensor,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_forward").stop()
-
-
-def send_backward(input_tensor_grad, pp_first_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward").start()
-    if not pp_first_stage:
-        _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=input_tensor_grad,
-            recv_prev=False,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_backward").stop()
-
-
-def send_forward_recv_backward(output_tensor, pp_last_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_recv_backward").start()
-    if pp_last_stage:
-        output_tensor_grad = None
-    else:
-        _, output_tensor_grad = _p2p_helper(
-            tensor_send_next=output_tensor,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=True,
-        )
-    if _timers is not None:
-        _timers("send_forward_recv_backward").stop()
-    return output_tensor_grad
-
-
-def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward_recv_forward").start()
-    if pp_first_stage:
-        input_tensor = None
-    else:
-        input_tensor, _ = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=input_tensor_grad,
-            recv_prev=True,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_backward_recv_forward").stop()
-    return input_tensor
-
-
-def send_forward_backward_recv_forward_backward(
-    output_tensor, input_tensor_grad, recv_prev, recv_next
-):
-    # always have to send dytpe info to downstream
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_backward_recv_forward_backward").start()
-    if not _send_recv_meta.has_send_meta:
-        _send_recv_meta.set_send_message(output_tensor)
-        _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_send_meta = _use_cache
-    if recv_prev and not _send_recv_meta.has_recv_meta:
-        _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_recv_meta = _use_cache
-    input_tensor, output_tensor_grad = _p2p_helper(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_forward_backward_recv_forward_backward").stop()
-    return input_tensor, output_tensor_grad
-
-
-def send_forward_recv_forward(output_tensor, recv_prev):
-    # always have to send dytpe info to downstream
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_recv_forward").start()
-    if not _send_recv_meta.has_send_meta:
-        _send_recv_meta.set_send_message(output_tensor)
-        _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_send_meta = _use_cache
-    if recv_prev and not _send_recv_meta.has_recv_meta:
-        _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_recv_meta = _use_cache
-
-    input_tensor, _ = _p2p_helper(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=None,
-        recv_prev=recv_prev,
-        recv_next=False,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_forward_recv_forward").stop()
-    return input_tensor
-
-
-def send_backward_recv_backward(input_tensor_grad, recv_next):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward_recv_backward").start()
-    _, output_tensor_grad = _p2p_helper(
-        tensor_send_next=None,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=False,
-        recv_next=recv_next,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_backward_recv_backward").stop()
-    return output_tensor_grad
+class P2pHelper:
+    def __init__(self, use_cache=True):
+        self._send_recv_meta = SendRecvMeta()
+        self._use_cache = use_cache
+
+    def _send_meta(self, output_tensor):
+        if not self._send_recv_meta.has_send_meta:
+            self._send_recv_meta.set_send_message(output_tensor)
+            self._send_recv_meta.send_meta(
+                output_tensor, _hcg.get_pipe_parallel_group()
+            )
+            self._send_recv_meta.has_send_meta = self._use_cache
+
+    def _recv_meta(self):
+        if not self._send_recv_meta.has_recv_meta:
+            self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
+            self._send_recv_meta.has_recv_meta = self._use_cache
+
+    def recv_forward(self, pp_first_stage, sync_recv=True):
+        global _timers
+        if _timers is not None:
+            _timers("recv_forward").start()
+        if pp_first_stage:
+            input_tensor = None
+        else:
+            self._recv_meta()
+            input_tensor, _ = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_prev=True,
+                recv_next=False,
+                sync_recv=sync_recv,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("recv_forward").stop()
+        return input_tensor
+
+    def recv_backward(self, pp_last_stage, sync_recv=True):
+        global _timers
+        if _timers is not None:
+            _timers("recv_backward").start()
+        if pp_last_stage:
+            output_tensor_grad = None
+        else:
+            _, output_tensor_grad = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=True,
+                sync_recv=sync_recv,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("recv_backward").stop()
+        return output_tensor_grad
+
+    def send_forward(self, output_tensor, pp_last_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_forward").start()
+        if not pp_last_stage:
+            self._send_meta(output_tensor)
+            _p2p_helper(
+                tensor_send_next=output_tensor,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_forward").stop()
+
+    def send_backward(self, input_tensor_grad, pp_first_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward").start()
+        if not pp_first_stage:
+            _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=input_tensor_grad,
+                recv_prev=False,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_backward").stop()
+
+    def send_forward_recv_backward(self, output_tensor, pp_last_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_recv_backward").start()
+        if pp_last_stage:
+            output_tensor_grad = None
+        else:
+            _, output_tensor_grad = _p2p_helper(
+                tensor_send_next=output_tensor,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=True,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_forward_recv_backward").stop()
+        return output_tensor_grad
+
+    def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward_recv_forward").start()
+        if pp_first_stage:
+            input_tensor = None
+        else:
+            input_tensor, _ = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=input_tensor_grad,
+                recv_prev=True,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_backward_recv_forward").stop()
+        return input_tensor
+
+    def send_forward_backward_recv_forward_backward(
+        self, output_tensor, input_tensor_grad, recv_prev, recv_next
+    ):
+        # always have to send dytpe info to downstream
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_backward_recv_forward_backward").start()
+        self._send_meta(output_tensor)
+        if recv_prev:
+            self._recv_meta()
+        input_tensor, output_tensor_grad = _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=recv_prev,
+            recv_next=recv_next,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_forward_backward_recv_forward_backward").stop()
+        return input_tensor, output_tensor_grad
+
+    def send_forward_recv_forward(self, output_tensor, recv_prev):
+        # always have to send dytpe info to downstream
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_recv_forward").start()
+
+        self._send_meta(output_tensor)
+        if recv_prev:
+            self._recv_meta()
+
+        input_tensor, _ = _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_prev=recv_prev,
+            recv_next=False,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_forward_recv_forward").stop()
+        return input_tensor
+
+    def send_backward_recv_backward(self, input_tensor_grad, recv_next):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward_recv_backward").start()
+        _, output_tensor_grad = _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=False,
+            recv_next=recv_next,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_backward_recv_backward").stop()
+        return output_tensor_grad
+
+    def __repr__(self):
+        debug_str = f"using cache: {self._use_cache} \n"
+        debug_str += repr(self._send_recv_meta)
+        return debug_str
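
A final note on the low-level helper: `_p2p_helper(...)` no longer reads a module-global `SendRecvMeta`; callers now pass `send_recv_meta=` explicitly, and the function asserts it is not None. The standalone sketch below uses hypothetical stand-ins (`FakeMeta`, `fake_p2p_helper`) to show that explicit-injection pattern; it is an illustration, not Paddle code.

    # Hypothetical stand-ins for SendRecvMeta / _p2p_helper, illustrating the
    # explicit send_recv_meta injection introduced above.
    class FakeMeta:
        recv_shape_message = (2, 3)
        recv_dtype_message = "float32"

    def fake_p2p_helper(recv_prev, send_recv_meta=None):
        # mirrors: assert send_recv_meta is not None, "send_recv_meta should not be None"
        assert send_recv_meta is not None, "send_recv_meta should not be None"
        if recv_prev:
            # a real implementation would allocate the receive buffer from the cached shape/dtype
            return ("recv", send_recv_meta.recv_shape_message, send_recv_meta.recv_dtype_message)
        return None

    print(fake_p2p_helper(recv_prev=True, send_recv_meta=FakeMeta()))
    # fake_p2p_helper(recv_prev=True)  # would fail fast: "send_recv_meta should not be None"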