未验证 提交 817f9ef0 编写于 作者: C caozhou 提交者: GitHub

fix pp comm init bug (#36377)

上级 192e08cb
......@@ -662,7 +662,10 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
def _init_comm_for_send_recv():
    """Lazily create and instantiate the global process group used by
    the reshard send/recv communication path.

    Bug fixed: the original code called
    ``PROCESS_GROUP_MAP["global_group"].is_instantiate()`` before making
    sure the entry existed, which raises ``KeyError`` the first time this
    runs with an empty ``PROCESS_GROUP_MAP``. The existence guard must
    come first.
    """
    # Create the global group on first use; it spans every rank in the
    # current global environment (world_size taken from the env helper).
    if not PROCESS_GROUP_MAP:
        genv = _get_global_env()
        PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
            0, list(range(genv.world_size)))
    # Instantiate at most once — is_instantiate() guards repeated calls.
    if not PROCESS_GROUP_MAP["global_group"].is_instantiate():
        PROCESS_GROUP_MAP["global_group"].instantiate()
......
......@@ -27,6 +27,7 @@ from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP
paddle.enable_static()
_global_parallel_strategy = None
......@@ -254,6 +255,8 @@ class TestMLPReshard(unittest.TestCase):
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
for key in list(PROCESS_GROUP_MAP.keys()):
del PROCESS_GROUP_MAP[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# check send and recv result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册