提交 44925eb4 编写于 作者: Y yi.wu

fix dist ut

上级 d07d9535
......@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
exit(0);
}
......
......@@ -22,9 +22,9 @@ from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor', 'load'
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
]
......@@ -177,18 +177,17 @@ class ListenAndServ(object):
})
def Send(endpoints, send_vars, get_vars=None):
def Send(endpoints, send_vars, sync=True):
"""
Send layer
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
Send variables to the server side, and get vars from server
side when server have finished running server side program.
send_vars (list): variables to send to server
sync (bool): whether to wait the request finish
"""
assert (type(send_vars) == list)
......@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars},
attrs={
"endpoints": endpoints,
"epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
if sync:
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
return get_vars
def Recv(endpoints, get_vars):
def Recv(endpoints, get_vars, sync=True):
"""
Recv layer
Receive variables from server side
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
get_vars (list): vars to get from server after send completes.
sync (bool): whether to wait the request finish
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Returns:
list: list of received variables
"""
assert (type(send_vars) == list)
assert (type(get_vars) == list)
epmap = endpoints.split(",")
......@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
if sync:
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
return get_vars
def monkey_patch_reader_methods(reader):
......
......@@ -16,6 +16,7 @@ import os
import time
import unittest
from multiprocessing import Process
import signal
import numpy
......@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers
class TestSendOp(unittest.TestCase):
@unittest.skip(
"This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
)
def test_send(self):
# Run init_serv in a thread
place = fluid.CPUPlace()
......@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase):
p.daemon = True
p.start()
time.sleep(10)
self.ps_timeout = 5
self._wait_ps_ready(p.pid)
with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
selected_port = int(fn.readlines()[0])
self.init_client(place, selected_port)
......@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase):
self.assertTrue(numpy.allclose(self.local_out, self.dist_out))
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
os.system("kill -9 %d" % p.pid)
os.kill(p.pid, signal.SIGKILL)
p.join()
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def init_serv(self, place):
main = fluid.Program()
......@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase):
dtype="float32",
persistable=False,
shape=[32, 32])
o = layers.Send("127.0.0.1:%d" % port, [x], [get_var])
fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
layers.Send("127.0.0.1:%d" % port, [x])
o = layers.Recv("127.0.0.1:%d" % port, [get_var])
exe = fluid.Executor(place)
self.dist_out = exe.run(main, fetch_list=o) # o is a list
......
......@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
def setUp(self):
self.ps_timeout = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.port = "0"
self.trainers = 1
self.trainer_id = 1
self.trainer_id = 0
def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id))
p.daemon = True
p.start()
return p.pid
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
......@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest):
def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
pid = self._start_pserver(False, True)
self._wait_ps_ready(pid)
p1 = self._start_pserver(False, True)
self._wait_ps_ready(p1.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p1.pid, signal.SIGKILL)
p1.join()
# run pserver on CPU in async mode
pid = self._start_pserver(False, False)
self._wait_ps_ready(pid)
p2 = self._start_pserver(False, False)
self._wait_ps_ready(p2.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p2.pid, signal.SIGKILL)
p2.join()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册