diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 5defec96bff88f8a863c30c8bd5b326daef3af02..f1a812391801b48730b448a6a6c82da8a3af9fc1 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -598,11 +598,12 @@ class PipelineLayer(Layer):
         return run_function
 
     def forward_function(self, start, end):
+        run_function = self.run_function
 
         def execute_func(*x):
             if len(x) == 1:
                 x = x[0]
-            for idx, layer in enumerate(self.run_function[start:end]):
+            for idx, layer in enumerate(run_function[start:end]):
                 x = layer(x)
             return x
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index c9566b188a54ee0494f73164fb66fc05dd42ebfb..ecd3dc7ab9136f9c3e66ae572f0ef5a180b50f26 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -168,17 +168,18 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
 
 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
+    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst, 'num', nranks, 'id',
-                                          rank_id)
+                                          'peer', dst_rank_in_group, 'num',
+                                          nranks, 'id', rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.send_partial_on_calc_stream \
             if use_calc_stream else group.process_group.send_partial
-        return comm_op(tensor, dst, nranks, rank_id)
+        return comm_op(tensor, dst_rank_in_group, nranks, rank_id)
 
 
 def send_partial(tensor,
@@ -192,12 +193,13 @@ def send_partial(tensor,
         return
 
     ring_id = 0 if group is None else group.id
+    dst_rank = _hcg._get_p2p_next_rank(
+    ) if dst == 1 else _hcg._get_p2p_prev_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
-                                nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
+                                dst_rank, nranks, rank_id)
     else:
-        dst_rank = _hcg._get_p2p_next_rank(
-        ) if dst == 1 else _hcg._get_p2p_prev_rank()
         if _in_legacy_dygraph():
             send_op = lambda x, dst, group: \
                 paddle.distributed.send(x, dst, group, use_calc_stream)
@@ -208,19 +210,21 @@ def send_partial(tensor,
 
 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
+    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src, 'num', nranks, 'id',
-                                          rank_id, 'dtype', tensor.dtype,
-                                          'out_shape', tensor.shape)
+                                          'peer', src_rank_in_group, 'num',
+                                          nranks, 'id', rank_id, 'dtype',
+                                          tensor.dtype, 'out_shape',
+                                          tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
         comm_op = group.process_group.recv_partial_on_calc_stream \
             if use_calc_stream else group.process_group.recv_partial
-        return comm_op(tensor, src, nranks, rank_id)
+        return comm_op(tensor, src_rank_in_group, nranks, rank_id)
 
 
 def recv_partial(tensor,
@@ -234,12 +238,13 @@ def recv_partial(tensor,
         return
 
     ring_id = 0 if group is None else group.id
+    src_rank = _hcg._get_p2p_prev_rank(
+    ) if src == 0 else _hcg._get_p2p_next_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
-                                nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
+                                src_rank, nranks, rank_id)
     else:
-        src_rank = _hcg._get_p2p_prev_rank(
-        ) if src == 0 else _hcg._get_p2p_next_rank()
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
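The pp_layers.py hunk rebinds `self.run_function` to a local variable before `execute_func` is defined, so the returned closure captures only the layer list and not the whole `PipelineLayer` instance. A minimal sketch of that difference, using a made-up `Holder` class rather than Paddle code:

```python
# Toy illustration of the closure change in forward_function.
# Holder and its lambda "layers" are invented for this sketch.

class Holder:

    def __init__(self):
        self.run_function = [lambda x: x + 1, lambda x: x * 2]

    def forward_function(self, start, end):
        run_function = self.run_function  # capture the list, not `self`

        def execute_func(x):
            # The closure holds a reference to the list only, so the
            # Holder instance is not kept alive by the returned function
            # and there is no attribute lookup on `self` per call.
            for layer in run_function[start:end]:
                x = layer(x)
            return x

        return execute_func


func = Holder().forward_function(0, 2)
assert func(3) == 8  # (3 + 1) * 2
```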
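In the p2p_communication.py hunks, `send_partial` and `recv_partial` now always resolve the peer to a global rank via `_hcg._get_p2p_next_rank()` / `_hcg._get_p2p_prev_rank()`, and `_partial_send_op` / `_partial_recv_op` translate that global rank into a rank inside the communication group through `group.get_group_rank(...)`, which is what the partial send/recv kernels expect as `peer`. A runnable sketch of the global-to-local mapping, with a toy `Group` class standing in for Paddle's group object and invented rank numbers:

```python
# Toy model of the global-rank -> group-rank translation that the patch
# performs via group.get_group_rank(dst); Group here is a stand-in.

class Group:
    """Stand-in for a communication group built from global ranks."""

    def __init__(self, ranks):
        self.ranks = list(ranks)  # global ranks, in group order

    def get_group_rank(self, global_rank):
        # Position inside the group: the value a partial send/recv
        # kernel expects as its 'peer' argument.
        return self.ranks.index(global_rank)


# Suppose an 8-GPU job where global ranks 2 and 6 form one p2p group.
pp_group = Group([2, 6])

dst = 6  # global rank of the next stage (as _get_p2p_next_rank would return)
dst_rank_in_group = dst if pp_group is None else pp_group.get_group_rank(dst)
assert dst_rank_in_group == 1  # the kernel must see 1, not the global 6
```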