diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index d66d799c6e0f9109c68f1f81eaa27fc0b6684070..2d54bf8a7887a3920a9ea1bcfaa19c943b05e212 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -662,7 +662,10 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
 
 
 def _init_comm_for_send_recv():
-    if not PROCESS_GROUP_MAP["global_group"].is_instantiate():
+    if not PROCESS_GROUP_MAP:
+        genv = _get_global_env()
+        PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
+            0, list(range(genv.world_size)))
         PROCESS_GROUP_MAP["global_group"].instantiate()
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
index 89e9b7e817f45762486da2d836ee6766d2c24500..da82e56d4a1518a2e9730598db6bb053ca9a6305 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
@@ -27,6 +27,7 @@ from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.completion import complete_backward_annotation
 from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP
 
 paddle.enable_static()
 _global_parallel_strategy = None
@@ -254,6 +255,8 @@ class TestMLPReshard(unittest.TestCase):
         dist_main_prog, dist_startup_prog = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
         complete_backward_annotation(dist_main_prog, dist_context)
+        for key in list(PROCESS_GROUP_MAP.keys()):
+            del PROCESS_GROUP_MAP[key]
         reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
 
         # check send and recv result
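
For context: the first hunk changes _init_comm_for_send_recv so it no longer assumes PROCESS_GROUP_MAP["global_group"] already exists; instead it lazily creates the global group from the global environment's world size when the map is empty, then instantiates it. The test hunk clears PROCESS_GROUP_MAP before each reshard call so the group is rebuilt per test case. Below is a minimal, self-contained sketch of that lazy-initialization pattern; the ProcessGroup class, world_size argument, and function name here are simplified stand-ins for illustration, not Paddle's actual implementation.

# Minimal sketch of the lazy-initialization pattern used in the patch.
# ProcessGroup and init_comm_for_send_recv are hypothetical stand-ins.

PROCESS_GROUP_MAP = {}


class ProcessGroup:
    def __init__(self, group_id, ranks):
        self.group_id = group_id
        self.ranks = ranks
        self._instantiated = False

    def instantiate(self):
        # Real code would set up the communicator here (e.g. NCCL init).
        self._instantiated = True

    def is_instantiate(self):
        return self._instantiated


def init_comm_for_send_recv(world_size):
    # Create the global group only when the map is empty, then
    # instantiate it, mirroring the patched _init_comm_for_send_recv.
    if not PROCESS_GROUP_MAP:
        PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
            0, list(range(world_size)))
        PROCESS_GROUP_MAP["global_group"].instantiate()


# Tests that run reshard repeatedly in one process clear the map first,
# as the unit-test hunk does, so each case rebuilds the group:
for key in list(PROCESS_GROUP_MAP.keys()):
    del PROCESS_GROUP_MAP[key]
init_comm_for_send_recv(world_size=4)
assert PROCESS_GROUP_MAP["global_group"].is_instantiate()

Clearing via "for key in list(...)" rather than iterating the dict directly avoids mutating the dictionary while iterating over it, which is why the test hunk is written that way.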