diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index fead95ffdab25c7ea96b7ef223efc0abf7eea3e3..c33539f6b50a3dc079e2a1e7820a63f264457b95 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -48,5 +48,7 @@ foreach(TEST_OP ${TEST_OPS})
 endforeach(TEST_OP)
 py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
 py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
-# tests that need to be done in fixed timeout
-set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
+# FIXME(Yancey1989): this test takes much longer on CUDAPlace because it
+# has to load the cuDNN libraries, so we use a longer timeout to keep this
+# unit test stable.
+set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 30)
diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
index cf89f9d0ebf6200933e539ef7fa8cbdc8f6db058..ad479657cc2fbeebcac03bdad2e16315882e2f01 100644
--- a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
@@ -23,7 +23,7 @@ from multiprocessing import Process
 from op_test import OpTest
 
 
-def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
+def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
     x = fluid.layers.data(name='x', shape=[1], dtype='float32')
     y_predict = fluid.layers.fc(input=x, size=1, act=None)
     y = fluid.layers.data(name='y', shape=[1], dtype='float32')
@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
-    port = os.getenv("PADDLE_INIT_PORT", port)
-    pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip)  # ip,ip...
-    eplist = []
-    for ip in pserver_ips.split(","):
-        eplist.append(':'.join([ip, port]))
-    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
-    trainers = int(os.getenv("TRAINERS", trainer_count))
-    current_endpoint = os.getenv("POD_IP", ip) + ":" + port
-    trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
+    pserver_endpoints = ip + ":" + port
+    current_endpoint = ip + ":" + port
     t = fluid.DistributeTranspiler()
     t.transpile(
         trainer_id,
@@ -62,47 +55,52 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
 
 class TestListenAndServOp(OpTest):
     def setUp(self):
-        self.sleep_time = 5
+        self.ps_timeout = 5
         self.ip = "127.0.0.1"
         self.port = "6173"
-        self.trainer_count = 1
+        self.trainers = 1
         self.trainer_id = 1
 
-    def _raise_signal(self, parent_pid, raised_signal):
-        time.sleep(self.sleep_time)
-        ps_command = subprocess.Popen(
-            "ps -o pid --ppid %d --noheaders" % parent_pid,
-            shell=True,
-            stdout=subprocess.PIPE)
-        ps_output = ps_command.stdout.read()
-        retcode = ps_command.wait()
-        assert retcode == 0, "ps command returned %d" % retcode
-
-        for pid_str in ps_output.split("\n")[:-1]:
-            try:
-                os.kill(int(pid_str), raised_signal)
-            except Exception:
-                continue
-
     def _start_pserver(self, use_cuda, sync_mode):
         p = Process(
             target=run_pserver,
-            args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
+            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                   self.trainer_id))
         p.start()
+        return p.pid
+
+    def _wait_ps_ready(self, pid):
+        """Block until the pserver process with the given pid is ready.
+
+        Polls once per second for the file /tmp/paddle.<pid>.port and fails
+        the test once the self.ps_timeout retries have been used up.
+        """
+        retry_times = self.ps_timeout
+        while True:
+            time.sleep(1)
+            assert retry_times >= 0, "wait ps ready failed"
+            try:
+                # listen_and_serv_op creates a file containing the listen
+                # port under /tmp once it is ready to process RPC calls.
+                os.stat("/tmp/paddle.%d.port" % pid)
+                return
+            except os.error:
+                retry_times -= 1
 
     def test_handle_signal_in_serv_op(self):
         # run pserver on CPU in sync mode
-        self._start_pserver(False, True)
+        pid = self._start_pserver(False, True)
+        self._wait_ps_ready(pid)
 
         # raise SIGINT to pserver
-        self._raise_signal(os.getpid(), signal.SIGINT)
+        os.kill(pid, signal.SIGINT)
 
         # run pserver on CPU in async mode
-        self._start_pserver(False, False)
+        pid = self._start_pserver(False, False)
+        self._wait_ps_ready(pid)
 
         # raise SIGTERM to pserver
-        self._raise_signal(os.getpid(), signal.SIGTERM)
+        os.kill(pid, signal.SIGTERM)
 
 
 if __name__ == '__main__':