Unverified commit 9be2b721, authored by Yuang Liu, committed by GitHub

Fix virtualpp with mp/recompute bugs (#47242)

Parent a9ac608f
@@ -598,11 +598,12 @@ class PipelineLayer(Layer):
         return run_function
 
     def forward_function(self, start, end):
+        run_function = self.run_function
 
         def execute_func(*x):
             if len(x) == 1:
                 x = x[0]
-            for idx, layer in enumerate(self.run_function[start:end]):
+            for idx, layer in enumerate(run_function[start:end]):
                 x = layer(x)
             return x
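Note on the one-line fix above (my reading; the commit message only says it fixes virtual pipeline with mp/recompute): under the interleaved schedule, `PipelineLayer` reassigns `self.run_function` when it switches model chunks, and recompute re-runs `execute_func` during backward. A closure that looks up `self.run_function` at call time can therefore recompute through the wrong chunk's layers, while binding the list to a local when `forward_function` runs freezes the chunk it was built for. A minimal sketch of the late- versus early-binding difference, with an invented `Stage` class standing in for `PipelineLayer`:

class Stage:

    def __init__(self, layers):
        self.run_function = layers

    def forward_function_late(self):

        def execute_func(x):
            for layer in self.run_function:  # attribute read at call time
                x = layer(x)
            return x

        return execute_func

    def forward_function_early(self):
        run_function = self.run_function  # bound once, when built

        def execute_func(x):
            for layer in run_function:
                x = layer(x)
            return x

        return execute_func


stage = Stage([lambda x: x + 1])
late, early = stage.forward_function_late(), stage.forward_function_early()
stage.run_function = [lambda x: x * 10]  # chunk switch rebinds the attribute
print(late(1))   # 10: silently runs the new chunk's layers
print(early(1))  # 2: still runs the layers it was built for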
@@ -168,17 +168,18 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
 
 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
+    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst, 'num', nranks, 'id',
-                                          rank_id)
+                                          'peer', dst_rank_in_group, 'num',
+                                          nranks, 'id', rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.send_partial_on_calc_stream \
             if use_calc_stream else group.process_group.send_partial
-        return comm_op(tensor, dst, nranks, rank_id)
+        return comm_op(tensor, dst_rank_in_group, nranks, rank_id)
 
 
 def send_partial(tensor,
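The substantive change in `_partial_send_op` is that the peer passed to the partial-send kernels is now the destination's rank inside `group` rather than its global rank. `Group.get_group_rank` performs that translation; a minimal stand-in for the mapping (illustrative, not Paddle's implementation):

class Group:

    def __init__(self, ranks):
        self.ranks = ranks  # global ranks belonging to this group, in order

    def get_group_rank(self, global_rank):
        # Position of a global rank inside the group: the communication op
        # addresses peers by this group-local index, not by global rank.
        return self.ranks.index(global_rank)


g = Group([4, 5, 6, 7])          # e.g. a p2p group over global ranks 4..7
assert g.get_group_rank(6) == 2  # global rank 6 is peer 2 inside the group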
@@ -192,12 +193,13 @@ def send_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
+    dst_rank = _hcg._get_p2p_next_rank(
+    ) if dst == 1 else _hcg._get_p2p_prev_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
-                                nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
+                                dst_rank, nranks, rank_id)
     else:
-        dst_rank = _hcg._get_p2p_next_rank(
-        ) if dst == 1 else _hcg._get_p2p_prev_rank()
         if _in_legacy_dygraph():
             send_op = lambda x, dst, group: \
                 paddle.distributed.send(x, dst, group, use_calc_stream)
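For context (a hedged reading of the code above): the `dst` reaching `send_partial` is a direction flag, 1 for the next pipeline stage and 0 for the previous one, not an actual rank. Previously only the full-tensor fallback resolved the flag via `_hcg`, while the partial branch forwarded the raw 0/1 to `_partial_send_op` as the peer; hoisting `dst_rank` makes both branches target the resolved neighbor. A toy illustration of the convention, with an invented helper in place of `_hcg`'s `_get_p2p_next_rank`/`_get_p2p_prev_rank`:

PREV, NEXT = 0, 1  # the dst/src direction flags used by send/recv_partial


def resolve_peer(direction, my_stage):
    # Invented stand-in: global rank of the pipeline neighbor, assuming a
    # linear pipeline with one rank per stage.
    return my_stage + 1 if direction == NEXT else my_stage - 1


assert resolve_peer(NEXT, my_stage=2) == 3  # what both branches now receive
assert resolve_peer(PREV, my_stage=2) == 1  # instead of the raw flag 0 or 1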
@@ -208,19 +210,21 @@ def send_partial(tensor,
 
 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
+    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src, 'num', nranks, 'id',
-                                          rank_id, 'dtype', tensor.dtype,
-                                          'out_shape', tensor.shape)
+                                          'peer', src_rank_in_group, 'num',
+                                          nranks, 'id', rank_id, 'dtype',
+                                          tensor.dtype, 'out_shape',
+                                          tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.recv_partial_on_calc_stream \
             if use_calc_stream else group.process_group.recv_partial
-        return comm_op(tensor, src, nranks, rank_id)
+        return comm_op(tensor, src_rank_in_group, nranks, rank_id)
 
 
 def recv_partial(tensor,
@@ -234,12 +238,13 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
+    src_rank = _hcg._get_p2p_prev_rank(
+    ) if src == 0 else _hcg._get_p2p_next_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
-                                nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
+                                src_rank, nranks, rank_id)
     else:
-        src_rank = _hcg._get_p2p_prev_rank(
-        ) if src == 0 else _hcg._get_p2p_next_rank()
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
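`recv_partial` mirrors `send_partial`: `src_rank` is hoisted ahead of the branch so the partial receive also targets the resolved neighbor. Both helpers gate the partial path on `_is_valid_send_recv_partial(tensor, nranks)`, which, judging from its use here, should hold only when the tensor splits evenly across the mp ranks. A hedged sketch of that predicate (the body below is an assumption, not part of this diff):

import numpy as np


def is_valid_partial(shape, mp_degree):
    # Assumed contract: partial p2p only pays off with mp_degree > 1, and is
    # only well-defined when the element count divides evenly across ranks.
    numel = int(np.prod(shape))
    assert numel != 0, "cannot send/recv a zero-element tensor"
    return mp_degree > 1 and numel % mp_degree == 0


print(is_valid_partial((4, 6), mp_degree=2))  # True  -> partial p2p op
print(is_valid_partial((3, 5), mp_degree=2))  # False -> full send/recv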