未验证 提交 a28a2026 编写于 作者: W WangXi 提交者: GitHub

fix test_gen_nccl_id_op failed (#30686)

上级 16427570
......@@ -16,6 +16,7 @@ import unittest
import os
import copy
from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
from threading import Thread
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10")
......@@ -30,8 +31,8 @@ def run_gen_ncc_id(attr):
nccl_comm_num = attr['nccl_comm_num']
use_hallreduce = attr['use_hierarchical_allreduce']
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
startup_program = paddle.static.default_startup_program()
main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
nccl_id_var = startup_program.global_block().create_var(
......@@ -62,9 +63,7 @@ def run_gen_ncc_id(attr):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
exe.run(startup_program)
class TestGenNcclIdOp(unittest.TestCase):
......@@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase):
procs = []
for i in range(nranks):
attr['trainer_id'] = i
# NOTE. multiprocessing cannot be covered by coverage
p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), ))
# NOTE: multiprocessing cannot be covered by coverage
p = Process(target=run_gen_ncc_id, args=(attr, ))
p.start()
procs.append(p)
for p in procs:
p.join()
wait(procs, timeout=120)
def test_flat(self):
print(">>> test gen flat nccl id")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册