提交 8939f17d 编写于 作者: Y Yancey1989

sppedup test_listen_and_serv_op

上级 9503dbb1
...@@ -48,5 +48,7 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -48,5 +48,7 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL) py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL) py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
# tests that need to be done in fixed timeout # FIXME(Yancey1989): this test would cost much more time on CUDAPlace
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) # since load cudnn libraries, so we use a longer timeout to make this
# unit test stability.
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 30)
...@@ -23,7 +23,7 @@ from multiprocessing import Process ...@@ -23,7 +23,7 @@ from multiprocessing import Process
from op_test import OpTest from op_test import OpTest
def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id): def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
x = fluid.layers.data(name='x', shape=[1], dtype='float32') x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None) y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32') y = fluid.layers.data(name='y', shape=[1], dtype='float32')
...@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id): ...@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
port = os.getenv("PADDLE_INIT_PORT", port) pserver_endpoints = ip + ":" + port
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip) # ip,ip... current_endpoint = ip + ":" + port
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("TRAINERS", trainer_count))
current_endpoint = os.getenv("POD_IP", ip) + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile( t.transpile(
trainer_id, trainer_id,
...@@ -62,47 +55,47 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id): ...@@ -62,47 +55,47 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
class TestListenAndServOp(OpTest): class TestListenAndServOp(OpTest):
def setUp(self): def setUp(self):
self.sleep_time = 5 self.ps_timeout = 5
self.ip = "127.0.0.1" self.ip = "127.0.0.1"
self.port = "6173" self.port = "6173"
self.trainer_count = 1 self.trainers = 1
self.trainer_id = 1 self.trainer_id = 1
def _raise_signal(self, parent_pid, raised_signal):
time.sleep(self.sleep_time)
ps_command = subprocess.Popen(
"ps -o pid --ppid %d --noheaders" % parent_pid,
shell=True,
stdout=subprocess.PIPE)
ps_output = ps_command.stdout.read()
retcode = ps_command.wait()
assert retcode == 0, "ps command returned %d" % retcode
for pid_str in ps_output.split("\n")[:-1]:
try:
os.kill(int(pid_str), raised_signal)
except Exception:
continue
def _start_pserver(self, use_cuda, sync_mode): def _start_pserver(self, use_cuda, sync_mode):
p = Process( p = Process(
target=run_pserver, target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count, args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id)) self.trainer_id))
p.start() p.start()
return p.pid
def _wait_ps_ready(self, pid):
retry_times = self.ps_timeout
while True:
time.sleep(1)
assert retry_times >= 0, "wait ps ready failed"
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
retry_times -= 1
def test_handle_signal_in_serv_op(self): def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode # run pserver on CPU in sync mode
self._start_pserver(False, True) pid = self._start_pserver(False, True)
self._wait_ps_ready(pid)
# raise SIGINT to pserver # raise SIGINT to pserver
self._raise_signal(os.getpid(), signal.SIGINT) os.kill(pid, signal.SIGINT)
# run pserver on CPU in async mode # run pserver on CPU in async mode
self._start_pserver(False, False) pid = self._start_pserver(False, False)
self._wait_ps_ready(pid)
# raise SIGTERM to pserver # raise SIGTERM to pserver
self._raise_signal(os.getpid(), signal.SIGTERM) os.kill(pid, signal.SIGINT)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册