未验证 提交 4873c20d 编写于 作者: L lilong12 提交者: GitHub

modify ut cmakefile (#28140)

* modify ut cmakefile, test=develop
上级 e8db4412
...@@ -15,12 +15,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) ...@@ -15,12 +15,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
list(APPEND DIST_TEST_OPS test_listen_and_serv_op) list(APPEND DIST_TEST_OPS test_listen_and_serv_op)
list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
list(APPEND DIST_TEST_OPS test_collective_reduce_api)
list(APPEND DIST_TEST_OPS test_collective_scatter_api)
list(APPEND DIST_TEST_OPS test_collective_barrier_api)
list(APPEND DIST_TEST_OPS test_collective_allreduce_api)
list(APPEND DIST_TEST_OPS test_collective_broadcast_api)
list(APPEND DIST_TEST_OPS test_collective_allgather_api)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests. #remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
...@@ -70,6 +64,12 @@ if(NOT WITH_GPU OR WIN32) ...@@ -70,6 +64,12 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter) LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter) LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api) LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
endif() endif()
#TODO(sunxiaolong01): Fix this unitest failed on GCC8. #TODO(sunxiaolong01): Fix this unitest failed on GCC8.
......
...@@ -37,30 +37,6 @@ class TestCollectiveAPIRunnerBase(object): ...@@ -37,30 +37,6 @@ class TestCollectiveAPIRunnerBase(object):
raise NotImplementedError( raise NotImplementedError(
"get model should be implemented by child class.") "get model should be implemented by child class.")
def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(
socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(
not_ready_endpoints) + "\n")
sys.stderr.flush()
time.sleep(3)
else:
break
def run_trainer(self, args): def run_trainer(self, args):
train_prog = fluid.Program() train_prog = fluid.Program()
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -157,8 +133,8 @@ class TestDistBase(unittest.TestCase): ...@@ -157,8 +133,8 @@ class TestDistBase(unittest.TestCase):
tr_cmd = "%s %s" tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file) tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file) tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err.log", "w") tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
tr1_pipe = open("/tmp/tr1_err.log", "w") tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w")
#print(tr0_cmd) #print(tr0_cmd)
tr0_proc = subprocess.Popen( tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(), tr0_cmd.strip().split(),
...@@ -179,9 +155,9 @@ class TestDistBase(unittest.TestCase): ...@@ -179,9 +155,9 @@ class TestDistBase(unittest.TestCase):
# close trainer file # close trainer file
tr0_pipe.close() tr0_pipe.close()
tr1_pipe.close() tr1_pipe.close()
with open("/tmp/tr0_err.log", "r") as f: with open("/tmp/tr0_err_%d.log" % os.getpid(), "r") as f:
sys.stderr.write('trainer 0 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
with open("/tmp/tr1_err.log", "r") as f: with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f:
sys.stderr.write('trainer 1 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
return pickle.loads(tr0_out), pickle.loads( return pickle.loads(tr0_out), pickle.loads(
tr1_out), tr0_proc.pid, tr1_proc.pid tr1_out), tr0_proc.pid, tr1_proc.pid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册