未验证 提交 f765e59c 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9327 from Yancey1989/fix_test_recv_op

Fix test_recv_op
...@@ -113,9 +113,9 @@ class ListenAndServ(object): ...@@ -113,9 +113,9 @@ class ListenAndServ(object):
which can receive variables from clients and run a block. which can receive variables from clients and run a block.
""" """
def __init__(self, endpoint, fan_in=1, optimizer_mode=True): def __init__(self, endpoint, inputs, fan_in=1, optimizer_mode=True):
self.helper = LayerHelper("listen_and_serv") self.helper = LayerHelper("listen_and_serv")
self.inputs = [] self.inputs = inputs
self.outputs = [] self.outputs = []
self.endpoint = endpoint self.endpoint = endpoint
self.fan_in = fan_in self.fan_in = fan_in
...@@ -160,18 +160,13 @@ class ListenAndServ(object): ...@@ -160,18 +160,13 @@ class ListenAndServ(object):
current_block = main_program.current_block() current_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self.parent_block()
params, grads = self.get_params_and_grads()
param_names = [p.name for p in params]
grad_names = [g.name for g in grads]
parent_block.append_op( parent_block.append_op(
type='listen_and_serv', type='listen_and_serv',
inputs={}, inputs={"X": self.inputs},
outputs={}, outputs={},
attrs={ attrs={
'endpoint': self.endpoint, 'endpoint': self.endpoint,
'Fanin': self.fan_in, 'Fanin': self.fan_in,
'ParamList': param_names,
'GradList': grad_names,
'OptimizeBlock': current_block 'OptimizeBlock': current_block
}) })
...@@ -196,10 +191,14 @@ def Send(endpoints, send_vars, get_vars): ...@@ -196,10 +191,14 @@ def Send(endpoints, send_vars, get_vars):
endpoints = list(set(epmap)) endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals()) helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
helper.append_op( helper.append_op(
type="send", type="send",
inputs={"X": send_vars}, inputs={"X": send_vars},
outputs={"Out": get_vars}, outputs={"Out": get_vars,
"RPCClient": rpc_client_var},
attrs={"endpoints": endpoints, attrs={"endpoints": endpoints,
"epmap": epmap}) "epmap": epmap})
......
...@@ -38,14 +38,15 @@ class TestRecvOp(unittest.TestCase): ...@@ -38,14 +38,15 @@ class TestRecvOp(unittest.TestCase):
def init_serv(self, place): def init_serv(self, place):
main = fluid.Program() main = fluid.Program()
with fluid.program_guard(main): with fluid.program_guard(main):
x = layers.data( serv = layers.ListenAndServ(
shape=[32, 32], "127.0.0.1:6174", ["X"], optimizer_mode=False)
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
serv = layers.ListenAndServ("127.0.0.1:6174", optimizer_mode=False)
with serv.do(): with serv.do():
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
o = layers.scale(x=x, scale=10.0) o = layers.scale(x=x, scale=10.0)
main.global_block().create_var( main.global_block().create_var(
name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape) name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册