From 817f9ef061166793bc0616540f86a9593e750c7f Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Wed, 13 Oct 2021 14:56:35 +0800 Subject: [PATCH] fix pp comm init bug (#36377) --- python/paddle/distributed/auto_parallel/reshard.py | 5 ++++- .../fluid/tests/unittests/test_auto_parallel_reshard.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index d66d799c6e0..2d54bf8a788 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -662,7 +662,10 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, def _init_comm_for_send_recv(): - if not PROCESS_GROUP_MAP["global_group"].is_instantiate(): + if not PROCESS_GROUP_MAP: + genv = _get_global_env() + PROCESS_GROUP_MAP["global_group"] = ProcessGroup( + 0, list(range(genv.world_size))) PROCESS_GROUP_MAP["global_group"].instantiate() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 89e9b7e817f..da82e56d4a1 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -27,6 +27,7 @@ from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.reshard import reshard +from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP paddle.enable_static() _global_parallel_strategy = None @@ -254,6 +255,8 @@ class TestMLPReshard(unittest.TestCase): dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) complete_backward_annotation(dist_main_prog, dist_context) + for key in list(PROCESS_GROUP_MAP.keys()): + del PROCESS_GROUP_MAP[key] reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) # check send and recv result -- GitLab