Unverified commit 5a9214d8, authored by ShenLiang, committed by GitHub

[Distributed]Fix cache p2p in pp (#56796)

* add usecache

* add p2p cache fix

* add cache
Parent: dfcfc8b7
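
Summary of the change below: the module-level p2p metadata cache (the `_use_cache` global together with the shared `_send_recv_meta`) is replaced by a per-instance `P2pHelper` that owns its own `SendRecvMeta` and `use_cache` flag, so each pipeline schedule caches the send/recv shape and dtype handshake independently. The following is a minimal, framework-free sketch of that caching behaviour only; `SendRecvMetaStub`, `P2pHelperSketch`, and the print calls are illustrative stand-ins, not Paddle APIs.

# Minimal sketch of the per-helper metadata cache introduced by this commit.
# The stub classes below stand in for Paddle's SendRecvMeta and the real
# send_meta/recv_meta collectives.

class SendRecvMetaStub:
    def __init__(self):
        self.has_send_meta = False
        self.has_recv_meta = False

    def send_meta(self, shape):
        print(f"exchanging send meta for shape {shape}")

    def recv_meta(self):
        print("exchanging recv meta")


class P2pHelperSketch:
    """Each helper owns its own meta cache; with use_cache=True the
    shape/dtype handshake runs once, otherwise it runs on every call."""

    def __init__(self, use_cache=True):
        self._send_recv_meta = SendRecvMetaStub()
        self._use_cache = use_cache

    def _send_meta(self, shape):
        if not self._send_recv_meta.has_send_meta:
            self._send_recv_meta.send_meta(shape)
            # Latch the flag only when caching is enabled.
            self._send_recv_meta.has_send_meta = self._use_cache

    def _recv_meta(self):
        if not self._send_recv_meta.has_recv_meta:
            self._send_recv_meta.recv_meta()
            self._send_recv_meta.has_recv_meta = self._use_cache


if __name__ == "__main__":
    helper = P2pHelperSketch(use_cache=True)
    helper._send_meta((2, 1024))  # first call: meta exchanged
    helper._send_meta((2, 1024))  # second call: served from the cache
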
@@ -199,11 +199,13 @@ class PipelineParallel(MetaParallelBase):
p2p.initialize_p2p_groups(
hcg,
self._using_cache,
self._enable_partial_send_recv,
self._enable_timer,
)
# construct pipeline meta info
self._p2p_helper = p2p.P2pHelper(self._using_cache)
self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
@@ -349,10 +351,14 @@ class PipelineParallel(MetaParallelBase):
micro_dataset = self._wrap_data(data)
for step_id in range(startup_steps):
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
input_tensor = self._p2p_helper.recv_forward(
self.is_pipeline_first_stage()
)
output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
self._p2p_helper.send_forward(
output_tensor, self.is_pipeline_last_stage()
)
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
@@ -361,14 +367,16 @@ class PipelineParallel(MetaParallelBase):
self._release_output(output_tensor)
if steady_steps > 0:
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
input_tensor = self._p2p_helper.recv_forward(
self.is_pipeline_first_stage()
)
for i in range(steady_steps):
last_iter = i == (steady_steps - 1)
output_tensor = self._forward_step(input_tensor, micro_dataset)
output_tensor_grad = p2p.send_forward_recv_backward(
output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
output_tensor, self.is_pipeline_last_stage()
)
@@ -388,11 +396,11 @@ class PipelineParallel(MetaParallelBase):
if last_iter:
input_tensor = None
p2p.send_backward(
self._p2p_helper.send_backward(
input_tensor_grad, self.is_pipeline_first_stage()
)
else:
input_tensor = p2p.send_backward_recv_forward(
input_tensor = self._p2p_helper.send_backward_recv_forward(
input_tensor_grad, self.is_pipeline_first_stage()
)
@@ -400,14 +408,16 @@ class PipelineParallel(MetaParallelBase):
input_tensor = input_buffers.pop(0)
output_tensor = output_buffers.pop(0)
output_tensor_grad = p2p.recv_backward(
output_tensor_grad = self._p2p_helper.recv_backward(
self.is_pipeline_last_stage()
)
input_tensor_grad = self._backward_step(
input_tensor, output_tensor, output_tensor_grad
)
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
self._p2p_helper.send_backward(
input_tensor_grad, self.is_pipeline_first_stage()
)
if self._comm_overlap:
assert (
@@ -513,28 +523,38 @@ class PipelineParallel(MetaParallelBase):
micro_dataset = self._wrap_data(data)
for step_id in range(startup_steps):
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
input_tensor = self._p2p_helper.recv_forward(
self.is_pipeline_first_stage()
)
output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
self._p2p_helper.send_forward(
output_tensor, self.is_pipeline_last_stage()
)
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if steady_steps > 0:
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
input_tensor = self._p2p_helper.recv_forward(
self.is_pipeline_first_stage()
)
for i in range(steady_steps):
last_iter = i == (steady_steps - 1)
output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
self._p2p_helper.send_forward(
output_tensor, self.is_pipeline_last_stage()
)
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if not last_iter:
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
input_tensor = self._p2p_helper.recv_forward(
self.is_pipeline_first_stage()
)
if self._compute_loss:
self.train_loss = self._broadcast_final_loss()
@@ -859,6 +879,11 @@ class PipelineParallelWithInterleave(PipelineParallel):
not forward_only
), "compute_loss can only be set to False when forward_only is set to True"
# NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
assert (
self._using_cache
), "cache should be enabled for pipeline with interleave"
# init some attributes for this batch run
self.scaler = scaler
self.total_loss = None
@@ -904,7 +929,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.set_virtual_pipeline_rank(0)
self.input_tensors[0].append(
p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
self._p2p_helper.recv_forward(
self.is_pipeline_first_stage(), sync_recv=False
)
)
# run startup steps
@@ -942,7 +969,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
(
input_tensor,
output_tensor_grad,
) = p2p.send_forward_backward_recv_forward_backward(
) = self._p2p_helper.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
@@ -952,7 +979,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
output_tensor_grad
)
else:
input_tensor = p2p.send_forward_recv_forward(
input_tensor = self._p2p_helper.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev
)
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
@@ -1033,7 +1060,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
(
input_tensor,
output_tensor_grad,
) = p2p.send_forward_backward_recv_forward_backward(
) = self._p2p_helper.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
@@ -1057,7 +1084,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
if not forward_only:
if all_startup_steps:
self.output_tensor_grads[self.num_model_chunks - 1].append(
p2p.recv_backward(
self._p2p_helper.recv_backward(
self.is_pipeline_last_stage(), sync_recv=False
)
)
@@ -1080,7 +1107,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
recv_next = False
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
p2p.send_backward_recv_backward(
self._p2p_helper.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next
)
)
@@ -31,17 +31,16 @@ from ...utils import timer_helper as timer
from .utils import number_2_dtype, paddle_2_number
_hcg = None
_use_cache = False
# _use_cache = False
_enable_partial_send_recv = True
_timers = None
def initialize_p2p_groups(
hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
hcg, enable_partial_send_recv=True, enable_timer=False
):
global _hcg, _use_cache, _enable_partial_send_recv, _timers
global _hcg, _enable_partial_send_recv, _timers
_hcg = hcg
_use_cache = use_cache
_enable_partial_send_recv = enable_partial_send_recv
if enable_timer:
_timers = timer.get_timers()
@@ -170,8 +169,14 @@ class SendRecvMeta:
]
)
_send_recv_meta = SendRecvMeta()
def __repr__(self):
return "send_shape_message: {}, send_dtype_message: {}, recv_shape_message: {}, recv_dtype_message: {}, recv_stop_gradient: {}".format(
self.send_shape_message,
self.send_dtype_message,
self.recv_shape_message,
self.recv_dtype_message,
self.recv_stop_gradient,
)
def _is_valid_send_recv_partial(tensor, mp_degree):
@@ -303,7 +308,12 @@ def _process_p2p_tuple_or_tensor(
def _p2p_helper(
tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
tensor_send_next,
tensor_send_prev,
recv_prev,
recv_next,
sync_recv=True,
send_recv_meta=None,
):
global _hcg
@@ -311,12 +321,13 @@ def _p2p_helper(
tensor_recv_next = None
# send / recv message
recv_shape_msg = _send_recv_meta.recv_shape_message
recv_dtype_msg = _send_recv_meta.recv_dtype_message
recv_stop_gradient = _send_recv_meta.recv_stop_gradient
assert send_recv_meta is not None, "send_recv_meta should not be None"
recv_shape_msg = send_recv_meta.recv_shape_message
recv_dtype_msg = send_recv_meta.recv_dtype_message
recv_stop_gradient = send_recv_meta.recv_stop_gradient
send_shape_msg = _send_recv_meta.send_shape_message
send_dtype_msg = _send_recv_meta.send_dtype_message
send_shape_msg = send_recv_meta.send_shape_message
send_dtype_msg = send_recv_meta.send_dtype_message
# model parallel message
mp_group = _hcg.get_model_parallel_group()
@@ -441,183 +452,195 @@ def _p2p_helper(
return tensor_recv_prev, tensor_recv_next
def recv_forward(pp_first_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
if not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
sync_recv=sync_recv,
)
if _timers is not None:
_timers("recv_forward").stop()
return input_tensor
class P2pHelper:
def __init__(self, use_cache=True):
self._send_recv_meta = SendRecvMeta()
self._use_cache = use_cache
def recv_backward(pp_last_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
sync_recv=sync_recv,
)
if _timers is not None:
_timers("recv_backward").stop()
return output_tensor_grad
def send_forward(output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward").start()
if not pp_last_stage:
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(
def _send_meta(self, output_tensor):
if not self._send_recv_meta.has_send_meta:
self._send_recv_meta.set_send_message(output_tensor)
self._send_recv_meta.send_meta(
output_tensor, _hcg.get_pipe_parallel_group()
)
_send_recv_meta.has_send_meta = _use_cache
_p2p_helper(
self._send_recv_meta.has_send_meta = self._use_cache
def _recv_meta(self):
if not self._send_recv_meta.has_recv_meta:
self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
self._send_recv_meta.has_recv_meta = self._use_cache
def recv_forward(self, pp_first_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
self._recv_meta()
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
sync_recv=sync_recv,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("recv_forward").stop()
return input_tensor
def recv_backward(self, pp_last_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
sync_recv=sync_recv,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("recv_backward").stop()
return output_tensor_grad
def send_forward(self, output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward").start()
if not pp_last_stage:
self._send_meta(output_tensor)
_p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_forward").stop()
def send_backward(self, input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward").start()
if not pp_first_stage:
_p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_backward").stop()
def send_forward_recv_backward(self, output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward_recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_forward_recv_backward").stop()
return output_tensor_grad
def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward_recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_backward_recv_forward").stop()
return input_tensor
def send_forward_backward_recv_forward_backward(
self, output_tensor, input_tensor_grad, recv_prev, recv_next
):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").start()
self._send_meta(output_tensor)
if recv_prev:
self._recv_meta()
input_tensor, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
)
if _timers is not None:
_timers("send_forward").stop()
def send_backward(input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward").start()
if not pp_first_stage:
_p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
recv_prev=recv_prev,
recv_next=recv_next,
sync_recv=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_backward").stop()
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").stop()
return input_tensor, output_tensor_grad
def send_forward_recv_forward(self, output_tensor, recv_prev):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_recv_forward").start()
def send_forward_recv_backward(output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward_recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
self._send_meta(output_tensor)
if recv_prev:
self._recv_meta()
input_tensor, _ = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
recv_prev=recv_prev,
recv_next=False,
sync_recv=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_forward_recv_backward").stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward_recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
input_tensor, _ = _p2p_helper(
if _timers is not None:
_timers("send_forward_recv_forward").stop()
return input_tensor
def send_backward_recv_backward(self, input_tensor_grad, recv_next):
global _timers
if _timers is not None:
_timers("send_backward_recv_backward").start()
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
recv_prev=False,
recv_next=recv_next,
sync_recv=False,
send_recv_meta=self._send_recv_meta,
)
if _timers is not None:
_timers("send_backward_recv_forward").stop()
return input_tensor
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev, recv_next
):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").start()
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
_send_recv_meta.has_send_meta = _use_cache
if recv_prev and not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
sync_recv=False,
)
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").stop()
return input_tensor, output_tensor_grad
def send_forward_recv_forward(output_tensor, recv_prev):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_recv_forward").start()
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
_send_recv_meta.has_send_meta = _use_cache
if recv_prev and not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, _ = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
sync_recv=False,
)
if _timers is not None:
_timers("send_forward_recv_forward").stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next):
global _timers
if _timers is not None:
_timers("send_backward_recv_backward").start()
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
sync_recv=False,
)
if _timers is not None:
_timers("send_backward_recv_backward").stop()
return output_tensor_grad
if _timers is not None:
_timers("send_backward_recv_backward").stop()
return output_tensor_grad
def __repr__(self):
debug_str = f"using cache: {self._use_cache} \n"
debug_str += repr(self._send_recv_meta)
return debug_str