From 9a8517fd8b909baddac7945f403763ec1e7bd0c7 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 24 Jan 2018 14:43:56 +0800 Subject: [PATCH] daemonize the server process --- python/paddle/v2/fluid/layers/io.py | 19 ++++++--- python/paddle/v2/fluid/tests/test_recv_op.py | 44 ++++++++++++++------ 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/python/paddle/v2/fluid/layers/io.py b/python/paddle/v2/fluid/layers/io.py index be581531d14..bc804a40434 100644 --- a/python/paddle/v2/fluid/layers/io.py +++ b/python/paddle/v2/fluid/layers/io.py @@ -136,17 +136,26 @@ class ListenAndServ(object): # simple recv mode, recv operators inputs. for iname in op.input_names: for in_var_name in op.input(iname): - params.append(parent_block.var(name)) - grads.append(parent_block.var(name)) + params.append(parent_block.var(in_var_name)) + grads.append(parent_block.var(in_var_name)) return params, grads + def parent_block(self): + prog = self.helper.main_program + parent_idx = prog.current_block().parent_idx + assert parent_idx >= 0 + parent_block = prog.block(parent_idx) + return parent_block + def complete_op(self): main_program = self.helper.main_program current_block = main_program.current_block() parent_block = self.parent_block() params, grads = self.get_params_and_grads() + param_names = [p.name for p in params] + grad_names = [g.name for g in grads] parent_block.append_op( type='recv', inputs={}, @@ -154,8 +163,8 @@ class ListenAndServ(object): attrs={ 'endpoint': self.endpoint, 'Fanin': self.fan_in, - 'ParamList': params, - 'GradList': grads, + 'ParamList': param_names, + 'GradList': grad_names, 'OptimizeBlock': current_block }) @@ -177,7 +186,7 @@ def Send(endpoints, send_vars, get_vars): assert (type(get_vars) == list) epmap = endpoints.split(",") - endpoints = set(epmap) + endpoints = list(set(epmap)) helper = LayerHelper("Send", **locals()) helper.append_op( diff --git a/python/paddle/v2/fluid/tests/test_recv_op.py b/python/paddle/v2/fluid/tests/test_recv_op.py index e06f468648f..6ebb58ed33c 100644 --- a/python/paddle/v2/fluid/tests/test_recv_op.py +++ b/python/paddle/v2/fluid/tests/test_recv_op.py @@ -17,40 +17,60 @@ import unittest import paddle.v2.fluid as fluid import paddle.v2.fluid.layers as layers import numpy -import threading +from multiprocessing import Process +import os, sys class TestRecvOp(unittest.TestCase): def test_send(self): # Run init_serv in a thread place = fluid.CPUPlace() - t = threading.Thread(target=self.init_serv, args=(place, )) - t.start() + p = Process(target=self.init_serv, args=(place, )) + p.daemon = True + p.start() self.init_client(place) - t.join() + # FIXME(typhoonzero): find a way to gracefully shutdown the server. + os.system("kill -9 %d" % p.pid) + p.join() def init_serv(self, place): main = fluid.Program() with fluid.program_guard(main): - x = layers.data(shape=[32, 32], dtype='float32', name='X') - i = fluid.initializer.Constant(value=1.0) - y = i(x, main.global_block()) - serv = layers.ListenAndServ("127.0.0.1:6174") + x = layers.data( + shape=[32, 32], + dtype='float32', + name="X", + append_batch_size=False) + fluid.initializer.Constant(value=1.0)(x, main.global_block()) + serv = layers.ListenAndServ("127.0.0.1:6174", optimizer_mode=False) with serv.do(): - layers.scale(input=y, scale=10.0) + o = layers.scale(x=x, scale=10.0) + main.global_block().create_var( + name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape) + print main exe = fluid.Executor(place) exe.run(main) def init_client(self, place): main = fluid.Program() with fluid.program_guard(main): - x = layers.data(shape=[32, 32], dtype='float32', name='X') - i = fluid.initializer.Constant(value=1.0) - i(x, main.global_block()) + x = layers.data( + shape=[32, 32], + dtype='float32', + name='X', + append_batch_size=False) + fluid.initializer.Constant(value=1.0)(x, main.global_block()) layers.Send("127.0.0.1:6174", [x], [x]) + print main exe = fluid.Executor(place) exe.run(main) if __name__ == "__main__": unittest.main() + # test = TestRecvOp() + # place = fluid.CPUPlace() + # if sys.argv[1] == "server": + # test.init_serv(place) + # else: + # test.init_client(place) -- GitLab