From 9be2b7217c71f145da77e991838bea930c732622 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Fri, 21 Oct 2022 14:58:57 +0800
Subject: [PATCH] Fix virtualpp with mp/recompute bugs (#47242)

---
 .../parallel_layers/pp_layers.py              |  3 +-
 .../pp_utils/p2p_communication.py             | 35 +++++++++++--------
 2 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 5defec96bff..f1a81239180 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -598,11 +598,12 @@ class PipelineLayer(Layer):
         return run_function
 
     def forward_function(self, start, end):
+        run_function = self.run_function
 
         def execute_func(*x):
             if len(x) == 1:
                 x = x[0]
-            for idx, layer in enumerate(self.run_function[start:end]):
+            for idx, layer in enumerate(run_function[start:end]):
                 x = layer(x)
             return x
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index c9566b188a5..ecd3dc7ab91 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -168,17 +168,18 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
 
 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
+    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst, 'num', nranks, 'id',
-                                          rank_id)
+                                          'peer', dst_rank_in_group, 'num',
+                                          nranks, 'id', rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.send_partial_on_calc_stream \
             if use_calc_stream else group.process_group.send_partial
-        return comm_op(tensor, dst, nranks, rank_id)
+        return comm_op(tensor, dst_rank_in_group, nranks, rank_id)
 
 
 def send_partial(tensor,
@@ -192,12 +193,13 @@ def send_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
+    dst_rank = _hcg._get_p2p_next_rank(
+    ) if dst == 1 else _hcg._get_p2p_prev_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
-                                nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
+                                dst_rank, nranks, rank_id)
     else:
-        dst_rank = _hcg._get_p2p_next_rank(
-        ) if dst == 1 else _hcg._get_p2p_prev_rank()
         if _in_legacy_dygraph():
             send_op = lambda x, dst, group: \
                 paddle.distributed.send(x, dst, group, use_calc_stream)
@@ -208,19 +210,21 @@ def send_partial(tensor,
 
 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
+    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src, 'num', nranks, 'id',
-                                          rank_id, 'dtype', tensor.dtype,
-                                          'out_shape', tensor.shape)
+                                          'peer', src_rank_in_group, 'num',
+                                          nranks, 'id', rank_id, 'dtype',
+                                          tensor.dtype, 'out_shape',
+                                          tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.recv_partial_on_calc_stream \
             if use_calc_stream else group.process_group.recv_partial
-        return comm_op(tensor, src, nranks, rank_id)
+        return comm_op(tensor, src_rank_in_group, nranks, rank_id)
 
 
 def recv_partial(tensor,
@@ -234,12 +238,13 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
+    src_rank = _hcg._get_p2p_prev_rank(
+    ) if src == 0 else _hcg._get_p2p_next_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
-                                nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
+                                src_rank, nranks, rank_id)
     else:
-        src_rank = _hcg._get_p2p_prev_rank(
-        ) if src == 0 else _hcg._get_p2p_next_rank()
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
--
GitLab
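
The pp_layers.py hunk guards against Python's late-binding closures: execute_func
used to read self.run_function each time it ran, so if the layer's run_function
list was swapped out between the original forward pass and a recompute replay
(as can happen when one PipelineLayer hosts several virtual pipeline chunks),
the replay would execute the wrong slice of layers. Binding run_function to a
local before defining the closure freezes the slice at creation time. Below is
a minimal sketch of the two behaviors, using a hypothetical Holder class rather
than Paddle's PipelineLayer:

    class Holder:
        # Hypothetical stand-in for PipelineLayer; run_function is the
        # list of sublayers a forward_function slice should execute.
        def __init__(self, fns):
            self.run_function = fns

        def forward_late(self, start, end):
            # Buggy pattern: the closure looks up self.run_function when
            # it is *called*, not when it is created.
            def execute_func(x):
                for fn in self.run_function[start:end]:
                    x = fn(x)
                return x

            return execute_func

        def forward_early(self, start, end):
            # Patched pattern: bind the list once, at creation time.
            run_function = self.run_function

            def execute_func(x):
                for fn in run_function[start:end]:
                    x = fn(x)
                return x

            return execute_func


    h = Holder([lambda x: x + 1, lambda x: x * 2])
    late, early = h.forward_late(0, 2), h.forward_early(0, 2)
    h.run_function = [lambda x: x - 100]  # e.g. another virtual chunk installed
    print(late(3))   # -97: replays the *new* list, i.e. the wrong layers
    print(early(3))  # 8: replays the list captured at creation time

With the late-binding version, reassigning run_function after the closure is
created changes what a later call executes; the patched version pins whatever
list was current when forward_function was called.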
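
The p2p_communication.py hunks separate the two rank spaces these helpers deal
with. Callers pass a direction indicator (dst == 1 for the next stage, src == 0
for the previous one), which send_partial/recv_partial now resolve to a concrete
global rank via the hybrid communicate group before branching, since
paddle.distributed.send/recv address peers by global rank; _partial_send_op and
_partial_recv_op then translate that global rank to a group-local one with
group.get_group_rank, which is what the process_group partial primitives appear
to expect. A toy illustration of that translation, with made-up ranks rather
than Paddle APIs:

    # A "group" here is just an ordered list of the global ranks it contains.
    group_ranks = [2, 3, 6, 7]  # hypothetical p2p group in an 8-GPU job

    def get_group_rank(global_rank):
        # Group-local rank = position of the global rank inside the group,
        # mirroring what group.get_group_rank(dst) returns in the patch.
        return group_ranks.index(global_rank)

    dst = 6                      # global rank, e.g. from _get_p2p_next_rank()
    print(get_group_rank(dst))   # 2: the peer id the partial ops would use

Hoisting the dst_rank/src_rank computation out of the else branch also means
both the partial and non-partial paths now start from the same explicit global
rank, instead of the partial path receiving the raw direction flag.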