diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index c613110612271f921956010e29cf86064f1b772f..f81164b778cc27158b44fa573d4dedfcf9a698dd 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import paddle
-import paddle.distributed as dist
 
 _groups = None
 _hcg = None
@@ -21,7 +20,10 @@ _hcg = None
 
 def initialize_p2p_groups(hcg):
     global _groups, _hcg
-    _groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()]
+    _groups = [
+        paddle.distributed.new_group(ranks=group)
+        for group in hcg.get_p2p_groups()
+    ]
     _hcg = hcg
 
 
@@ -33,7 +35,7 @@ def send(tensor, dest_stage):
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
     dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
-    return dist.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.broadcast(tensor, src_rank, group=group)
 
 
 def recv(tensor, src_stage):
@@ -43,7 +45,7 @@ def recv(tensor, src_stage):
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
     src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
-    return dist.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.broadcast(tensor, src_rank, group=group)
 
 
 def _is_valid_communciate(src_stage, dest_stage):