From 0e850c7417084675dcc997768c7f854333625bfe Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 23 Jan 2018 20:26:00 +0800 Subject: [PATCH] WIP --- python/paddle/v2/fluid/layers/io.py | 23 +++++++++++++++----- python/paddle/v2/fluid/tests/test_recv_op.py | 21 +++++++++++++----- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/python/paddle/v2/fluid/layers/io.py b/python/paddle/v2/fluid/layers/io.py index 6a6c561641e..be581531d14 100644 --- a/python/paddle/v2/fluid/layers/io.py +++ b/python/paddle/v2/fluid/layers/io.py @@ -14,8 +14,10 @@ from .. import core from ..layer_helper import LayerHelper +from control_flow import BlockGuard +from ..layer_helper import LayerHelper -__all__ = ['data'] +__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send'] def data(name, @@ -105,12 +107,14 @@ class ListenAndServ(object): which can receive variables from clients and run a block. """ - def __init__(self, endpoint, fan_in=1): - self.helper = LayerHelper("recv", name=name) + def __init__(self, endpoint, fan_in=1, optimizer_mode=True): + self.helper = LayerHelper("recv") self.inputs = [] self.outputs = [] self.endpoint = endpoint self.fan_in = fan_in + # FIXME(typhoonzero): Add this switch is stupid + self.optimizer_mode = optimizer_mode def do(self): return BlockGuardServ(self) @@ -124,9 +128,16 @@ class ListenAndServ(object): grads = list() for op in current_block.ops: # FIXME(typhoonzero): op.inputs is None if it's cloned. - if "Grad" in op.inputs and "Param" in op.inputs: - params.append(op.inputs["Param"].name) - grads.append(op.inputs["Grad"].name) + if self.optimizer_mode: + if "Grad" in op.inputs and "Param" in op.inputs: + params.append(op.inputs["Param"].name) + grads.append(op.inputs["Grad"].name) + else: + # simple recv mode, recv operators inputs. 
+ for iname in op.input_names: + for in_var_name in op.input(iname): + params.append(parent_block.var(in_var_name)) + grads.append(parent_block.var(in_var_name)) return params, grads diff --git a/python/paddle/v2/fluid/tests/test_recv_op.py b/python/paddle/v2/fluid/tests/test_recv_op.py index fbd182a7162..e06f468648f 100644 --- a/python/paddle/v2/fluid/tests/test_recv_op.py +++ b/python/paddle/v2/fluid/tests/test_recv_op.py @@ -17,20 +17,27 @@ import unittest import paddle.v2.fluid as fluid import paddle.v2.fluid.layers as layers import numpy +import threading class TestRecvOp(unittest.TestCase): - def run_test(self): + def test_send(self): # Run init_serv in a thread - pass + place = fluid.CPUPlace() + t = threading.Thread(target=self.init_serv, args=(place, )) + t.start() + self.init_client(place) + t.join() def init_serv(self, place): main = fluid.Program() with fluid.program_guard(main): x = layers.data(shape=[32, 32], dtype='float32', name='X') - serv = fluid.ListenAndServ("127.0.0.1:6174") + i = fluid.initializer.Constant(value=1.0) + y = i(x, main.global_block()) + serv = layers.ListenAndServ("127.0.0.1:6174") with serv.do(): - layers.scale(input=x, scale=10) + layers.scale(input=y, scale=10.0) exe = fluid.Executor(place) exe.run(main) @@ -38,8 +45,12 @@ class TestRecvOp(unittest.TestCase): main = fluid.Program() with fluid.program_guard(main): x = layers.data(shape=[32, 32], dtype='float32', name='X') - i = fluid.initializer.Constant(x=1.0) + i = fluid.initializer.Constant(value=1.0) i(x, main.global_block()) layers.Send("127.0.0.1:6174", [x], [x]) exe = fluid.Executor(place) exe.run(main) + + +if __name__ == "__main__": + unittest.main() -- GitLab