未验证 提交 a28a2026 编写于 作者: W WangXi 提交者: GitHub

fix test_gen_nccl_id_op failed (#30686)

上级 16427570
...@@ -16,6 +16,7 @@ import unittest ...@@ -16,6 +16,7 @@ import unittest
import os import os
import copy import copy
from launch_function_helper import wait, _find_free_port from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
from threading import Thread from threading import Thread
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10") os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10")
...@@ -30,8 +31,8 @@ def run_gen_ncc_id(attr): ...@@ -30,8 +31,8 @@ def run_gen_ncc_id(attr):
nccl_comm_num = attr['nccl_comm_num'] nccl_comm_num = attr['nccl_comm_num']
use_hallreduce = attr['use_hierarchical_allreduce'] use_hallreduce = attr['use_hierarchical_allreduce']
startup_program = paddle.static.Program() startup_program = paddle.static.default_startup_program()
main_program = paddle.static.Program() main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
nccl_id_var = startup_program.global_block().create_var( nccl_id_var = startup_program.global_block().create_var(
...@@ -62,9 +63,7 @@ def run_gen_ncc_id(attr): ...@@ -62,9 +63,7 @@ def run_gen_ncc_id(attr):
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.static.Scope() exe.run(startup_program)
with paddle.static.scope_guard(scope):
exe.run(startup_program)
class TestGenNcclIdOp(unittest.TestCase): class TestGenNcclIdOp(unittest.TestCase):
...@@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase): ...@@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase):
procs = [] procs = []
for i in range(nranks): for i in range(nranks):
attr['trainer_id'] = i attr['trainer_id'] = i
# NOTE. multiprocessing cannot be covered by coverage # NOTE: multiprocessing cannot be covered by coverage
p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), )) p = Process(target=run_gen_ncc_id, args=(attr, ))
p.start() p.start()
procs.append(p) procs.append(p)
for p in procs: wait(procs, timeout=120)
p.join()
def test_flat(self): def test_flat(self):
print(">>> test gen flat nccl id") print(">>> test gen flat nccl id")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册