diff --git a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py index 17df3347dc491e69aa53a991bce1fa22ad83ac98..c5e48e27a75d5672542d26b3a0150fabfb9f5e5a 100644 --- a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py @@ -16,6 +16,7 @@ import unittest import os import copy from launch_function_helper import wait, _find_free_port +from multiprocessing import Pool, Process from threading import Thread os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10") @@ -30,8 +31,8 @@ def run_gen_ncc_id(attr): nccl_comm_num = attr['nccl_comm_num'] use_hallreduce = attr['use_hierarchical_allreduce'] - startup_program = paddle.static.Program() - main_program = paddle.static.Program() + startup_program = paddle.static.default_startup_program() + main_program = paddle.static.default_main_program() with paddle.static.program_guard(main_program, startup_program): nccl_id_var = startup_program.global_block().create_var( @@ -62,9 +63,7 @@ def run_gen_ncc_id(attr): place = paddle.CPUPlace() exe = paddle.static.Executor(place) - scope = paddle.static.Scope() - with paddle.static.scope_guard(scope): - exe.run(startup_program) + exe.run(startup_program) class TestGenNcclIdOp(unittest.TestCase): @@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase): procs = [] for i in range(nranks): attr['trainer_id'] = i - # NOTE. multiprocessing cannot be covered by coverage - p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), )) + # NOTE: multiprocessing cannot be covered by coverage + p = Process(target=run_gen_ncc_id, args=(attr, )) p.start() procs.append(p) - for p in procs: - p.join() + wait(procs, timeout=120) def test_flat(self): print(">>> test gen flat nccl id")