未验证 提交 5ea039b3 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #11470 from typhoonzero/fix_unitests

Fix dist ut
...@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
void SignalHandler::StopAndExit(int signal_num) { void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit"; // Do not use VLOG here: the device used for printing may already have been released.
// exit will release internally allocated resources.
exit(0); exit(0);
} }
......
...@@ -22,9 +22,9 @@ from ..executor import global_scope ...@@ -22,9 +22,9 @@ from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', 'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'random_data_generator', 'Preprocessor', 'load' 'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
] ]
...@@ -177,18 +177,17 @@ class ListenAndServ(object): ...@@ -177,18 +177,17 @@ class ListenAndServ(object):
}) })
def Send(endpoints, send_vars, get_vars=None): def Send(endpoints, send_vars, sync=True):
""" """
Send layer Send variables to the server side, and get vars from server
side when the server has finished running the server-side program.
Args: Args:
endpoints: comma separated IP:PORT pairs in the order endpoints (str): comma separated IP:PORT pairs in the order
of send_vars to send of send_vars to send
send_vars: vars to send send_vars (list): variables to send to server
get_vars: vars to get from server after send completes. sync (bool): whether to wait for the request to finish
Send variables to the server side, and get vars from server
side when server have finished running server side program.
""" """
assert (type(send_vars) == list) assert (type(send_vars) == list)
...@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None): ...@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap)) endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals()) helper = LayerHelper("Send", **locals())
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName() rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op( helper.append_op(
type="send", type="send",
inputs={"X": send_vars}, inputs={"X": send_vars},
outputs={"Out": get_vars},
attrs={ attrs={
"endpoints": endpoints, "endpoints": endpoints,
"epmap": epmap, "epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
}) })
if sync:
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
return get_vars
def Recv(endpoints, get_vars, sync=True):
def Recv(endpoints, get_vars):
""" """
Recv layer Receive variables from server side
Args: Args:
endpoints: comma separated IP:PORT pairs in the order endpoints (str): comma separated IP:PORT pairs in the order
of send_vars to send of send_vars to send
send_vars: vars to send get_vars (list): vars to get from server after send completes.
get_vars: vars to get from server after send completes. sync (bool): whether to wait for the request to finish
Send variables to the server side, and get vars from server Returns:
side when server have finished running server side program. list: list of received variables
""" """
assert (type(send_vars) == list)
assert (type(get_vars) == list) assert (type(get_vars) == list)
epmap = endpoints.split(",") epmap = endpoints.split(",")
...@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars): ...@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars}, outputs={"Out": get_vars},
attrs={"endpoints": endpoints, attrs={"endpoints": endpoints,
"epmap": epmap}) "epmap": epmap})
if sync:
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
return get_vars
def monkey_patch_reader_methods(reader): def monkey_patch_reader_methods(reader):
......
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import time import time
import unittest import unittest
from multiprocessing import Process from multiprocessing import Process
import signal
import numpy import numpy
...@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers ...@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers
class TestSendOp(unittest.TestCase): class TestSendOp(unittest.TestCase):
@unittest.skip(
"This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
)
def test_send(self): def test_send(self):
# Run init_serv in a thread # Run init_serv in a thread
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase): ...@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase):
p.daemon = True p.daemon = True
p.start() p.start()
time.sleep(10) self.ps_timeout = 5
self._wait_ps_ready(p.pid)
with open("/tmp/paddle.%d.port" % p.pid, "r") as fn: with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
selected_port = int(fn.readlines()[0]) selected_port = int(fn.readlines()[0])
self.init_client(place, selected_port) self.init_client(place, selected_port)
...@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase): ...@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase):
self.assertTrue(numpy.allclose(self.local_out, self.dist_out)) self.assertTrue(numpy.allclose(self.local_out, self.dist_out))
# FIXME(typhoonzero): find a way to gracefully shutdown the server. # FIXME(typhoonzero): find a way to gracefully shutdown the server.
os.system("kill -9 %d" % p.pid) os.kill(p.pid, signal.SIGKILL)
p.join() p.join()
def _wait_ps_ready(self, pid):
# Block until the port file "/tmp/paddle.<pid>.port" exists, polling every
# 0.5 seconds; fail with an AssertionError once the budget taken from
# self.ps_timeout (seconds) has gone negative. Returns None on success.
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
# The check is `>= 0`, so one more probe still happens when the remaining
# budget is exactly zero.
# NOTE(review): `assert` is removed under `python -O`; without it this loop
# would never time out.
assert start_left_time >= 0, "wait ps ready failed"
# Sleep before probing, so even the first check happens after one interval.
time.sleep(sleep_time)
try:
# Per the original author's note: listen_and_serv_op creates a file in
# /tmp that contains the listening port once it is ready to process
# RPC calls, so the file's existence is used as the readiness signal.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
# The file is not there yet (os.stat failed): charge this interval
# against the remaining budget and try again.
start_left_time -= sleep_time
def init_serv(self, place): def init_serv(self, place):
main = fluid.Program() main = fluid.Program()
...@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase): ...@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase):
dtype="float32", dtype="float32",
persistable=False, persistable=False,
shape=[32, 32]) shape=[32, 32])
o = layers.Send("127.0.0.1:%d" % port, [x], [get_var]) fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
layers.Send("127.0.0.1:%d" % port, [x])
o = layers.Recv("127.0.0.1:%d" % port, [get_var])
exe = fluid.Executor(place) exe = fluid.Executor(place)
self.dist_out = exe.run(main, fetch_list=o) # o is a list self.dist_out = exe.run(main, fetch_list=o) # o is a list
......
...@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest): ...@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
def setUp(self): def setUp(self):
self.ps_timeout = 5 self.ps_timeout = 5
self.ip = "127.0.0.1" self.ip = "127.0.0.1"
self.port = "6173" self.port = "0"
self.trainers = 1 self.trainers = 1
self.trainer_id = 1 self.trainer_id = 0
def _start_pserver(self, use_cuda, sync_mode): def _start_pserver(self, use_cuda, sync_mode):
p = Process( p = Process(
target=run_pserver, target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id)) self.trainer_id))
p.daemon = True
p.start() p.start()
return p.pid return p
def _wait_ps_ready(self, pid): def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout start_left_time = self.ps_timeout
...@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest): ...@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest):
def test_handle_signal_in_serv_op(self): def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode # run pserver on CPU in sync mode
pid = self._start_pserver(False, True) p1 = self._start_pserver(False, True)
self._wait_ps_ready(pid) self._wait_ps_ready(p1.pid)
# raise SIGTERM to pserver # raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM) os.kill(p1.pid, signal.SIGKILL)
p1.join()
# run pserver on CPU in async mode # run pserver on CPU in async mode
pid = self._start_pserver(False, False) p2 = self._start_pserver(False, False)
self._wait_ps_ready(pid) self._wait_ps_ready(p2.pid)
# raise SIGTERM to pserver # raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM) os.kill(p2.pid, signal.SIGKILL)
p2.join()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册