未验证 提交 9cc3f69f 编写于 作者: Y Yuang Liu 提交者: GitHub

Cherry pick for dygraph pp (#46876)

* bug fix for virtual pipeline parallel (#45922)

* dont wait for send op under dygraph pp (#46209)

* [interleave pp] sync recv for 1f1b (#46399)

* [dygraph pp] all sync for allgather partial (#46483)
上级 6a6c7493
......@@ -378,7 +378,7 @@ class PipelineLayer(Layer):
for virtual_pp_rank in range(self._num_virtual_pipeline_stages):
# Mapping the virtual pipeline stage to the real pipeline stage.
# start_idx marks the start of a new virtual pp stage.
start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages
start_idx = virtual_pp_rank * self._num_stages
for stage in range(self._num_stages):
# stage mark the real pp stage
if self.segment_parts[start_idx +
......@@ -484,7 +484,7 @@ class PipelineLayer(Layer):
", ".join(str(arg) for arg in self.segment_parts))
for i in range(self._stage_id, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages):
self._num_stages):
# If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
# Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
# Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
......@@ -529,7 +529,7 @@ class PipelineLayer(Layer):
stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
stage)
for i in range(stage, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages):
self._num_stages):
stage_to_virtual_stage_info += " {},".format(i)
logger.info(stage_to_virtual_stage_info)
......
......@@ -526,7 +526,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.set_virtual_pipeline_rank(0)
self.input_tensors[0].append(
p2p.recv_forward(self.is_pipeline_first_stage()))
p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False))
# run startup steps
for micro_step in range(startup_steps):
......@@ -647,7 +647,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
if not forward_only:
if all_startup_steps:
self.output_tensor_grads[self.num_model_chunks - 1].append(
p2p.recv_backward(self.is_pipeline_last_stage()))
p2p.recv_backward(self.is_pipeline_last_stage(),
sync_recv=False))
for micro_step in range(steady_steps, num_steps):
# cooldown loop
......
......@@ -165,17 +165,15 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
rank_id):
dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
if _in_legacy_dygraph():
return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'peer', dst_rank_in_group, 'num',
nranks, 'id', rank_id)
'peer', dst, 'num', nranks, 'id',
rank_id)
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
return group.process_group.send_partial(tensor, dst_rank_in_group,
nranks, rank_id)
return group.process_group.send_partial(tensor, dst, nranks, rank_id)
def send_partial(tensor,
......@@ -189,13 +187,12 @@ def send_partial(tensor,
return
ring_id = 0 if group is None else group.id
dst_rank = _hcg._get_p2p_next_rank(
) if dst == 1 else _hcg._get_p2p_prev_rank()
if _is_valid_send_recv_partial(tensor, nranks):
return _partial_send_op(tensor, group, use_calc_stream, ring_id,
dst_rank, nranks, rank_id)
return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
nranks, rank_id)
else:
dst_rank = _hcg._get_p2p_next_rank(
) if dst == 1 else _hcg._get_p2p_prev_rank()
if _in_legacy_dygraph():
send_op = paddle.distributed.send
elif in_dygraph_mode():
......@@ -205,19 +202,22 @@ def send_partial(tensor,
def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
rank_id):
src_rank_in_group = src if group is None else group.get_group_rank(src)
if _in_legacy_dygraph():
assert use_calc_stream
return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'peer', src_rank_in_group, 'num',
nranks, 'id', rank_id, 'dtype',
tensor.dtype, 'out_shape',
tensor.shape)
'peer', src, 'num', nranks, 'id',
rank_id, 'dtype', tensor.dtype,
'out_shape', tensor.shape)
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
return group.process_group.recv_partial(tensor, src_rank_in_group,
nranks, rank_id)
task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
def recv_partial(tensor,
......@@ -231,14 +231,13 @@ def recv_partial(tensor,
return
ring_id = 0 if group is None else group.id
src_rank = _hcg._get_p2p_prev_rank(
) if src == 0 else _hcg._get_p2p_next_rank()
if _is_valid_send_recv_partial(tensor, nranks):
return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
src_rank, nranks, rank_id)
return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
nranks, rank_id)
else:
if _in_legacy_dygraph():
src_rank = _hcg._get_p2p_prev_rank(
) if src == 0 else _hcg._get_p2p_next_rank()
if _in_legacy_dygraph() or use_calc_stream:
recv_op = paddle.distributed.recv
elif in_dygraph_mode():
recv_op = paddle.distributed.irecv
......@@ -256,8 +255,13 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
return group.process_group.all_gather_partial(tensor, tensor, nranks,
task = group.process_group.all_gather_partial(tensor, tensor, nranks,
rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
def allgather_partial(tensor,
......@@ -266,16 +270,20 @@ def allgather_partial(tensor,
group=None,
use_calc_stream=True):
if not _is_valid_send_recv_partial(tensor, nranks):
return None
return tensor
if group is not None and not group.is_member():
return None
return
ring_id = 0 if group is None else group.id
return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
nranks, rank_id)
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
def _p2p_helper(tensor_send_next,
tensor_send_prev,
recv_prev,
recv_next,
sync_recv=True):
global _hcg
tensor_recv_prev = None
......@@ -327,90 +335,111 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if tensor_send_prev is not None:
if isinstance(tensor_send_prev, tuple):
for d in tensor_send_prev:
if _in_legacy_dygraph():
paddle.distributed.wait(d, use_calc_stream=True)
tasks.append(
send_partial(d,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False))
use_calc_stream=False)
else:
if _in_legacy_dygraph():
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
tasks.append(
send_partial(tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False))
use_calc_stream=False)
if tensor_recv_prev is not None:
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
tasks.append(
recv_partial(d,
task = recv_partial(d,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True))
use_calc_stream=sync_recv)
if sync_recv:
allgather_partial(d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
tasks.append(
recv_partial(tensor_recv_prev,
tasks.append(task)
else:
task = recv_partial(tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True))
use_calc_stream=sync_recv)
if sync_recv:
allgather_partial(tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
tasks.append(task)
if tensor_send_next is not None:
if isinstance(tensor_send_next, tuple):
for d in tensor_send_next:
if _in_legacy_dygraph():
paddle.distributed.wait(d, use_calc_stream=True)
tasks.append(
send_partial(d,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False))
use_calc_stream=False)
else:
if _in_legacy_dygraph():
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
tasks.append(
send_partial(tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False))
use_calc_stream=False)
if tensor_recv_next is not None:
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
tasks.append(
recv_partial(d,
task = recv_partial(d,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True))
use_calc_stream=sync_recv)
if sync_recv:
allgather_partial(d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
tasks.append(task)
else:
tasks.append(
recv_partial(tensor_recv_next,
task = recv_partial(tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True))
use_calc_stream=sync_recv)
if sync_recv:
allgather_partial(tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
tasks.append(task)
if not sync_recv:
if in_dygraph_mode():
# wait isend/irecv tasks in eager dygraph mode with new comm library
# wait irecv tasks in eager dygraph mode with new comm library
for task in tasks:
assert task is not None
task.wait()
......@@ -429,24 +458,17 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
tensors_for_all_gather.append(tensor_recv_next)
tasks = []
for tensor in tensors_for_all_gather:
tasks.append(
allgather_partial(tensor,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True))
for task in tasks:
# wait partial all gather tasks
if task is not None:
task.wait()
use_calc_stream=True)
return tensor_recv_prev, tensor_recv_next
def recv_forward(pp_first_stage):
def recv_forward(pp_first_stage, sync_recv=True):
if pp_first_stage:
input_tensor = None
else:
......@@ -457,18 +479,20 @@ def recv_forward(pp_first_stage):
input_tensor, _ = _p2p_helper(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
recv_next=False,
sync_recv=sync_recv)
return input_tensor
def recv_backward(pp_last_stage):
def recv_backward(pp_last_stage, sync_recv=True):
if pp_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
recv_next=True,
sync_recv=sync_recv)
return output_tensor_grad
......@@ -530,7 +554,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next)
recv_next=recv_next,
sync_recv=False)
return input_tensor, output_tensor_grad
......@@ -547,7 +572,8 @@ def send_forward_recv_forward(output_tensor, recv_prev):
input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False)
recv_next=False,
sync_recv=False)
return input_tensor
......@@ -556,5 +582,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next):
_, output_tensor_grad = _p2p_helper(tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next)
recv_next=recv_next,
sync_recv=False)
return output_tensor_grad
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册