未验证 提交 f765e59c 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9327 from Yancey1989/fix_test_recv_op

Fix test_recv_op
......@@ -113,9 +113,9 @@ class ListenAndServ(object):
which can receive variables from clients and run a block.
"""
def __init__(self, endpoint, fan_in=1, optimizer_mode=True):
def __init__(self, endpoint, inputs, fan_in=1, optimizer_mode=True):
self.helper = LayerHelper("listen_and_serv")
self.inputs = []
self.inputs = inputs
self.outputs = []
self.endpoint = endpoint
self.fan_in = fan_in
......@@ -160,18 +160,13 @@ class ListenAndServ(object):
current_block = main_program.current_block()
parent_block = self.parent_block()
params, grads = self.get_params_and_grads()
param_names = [p.name for p in params]
grad_names = [g.name for g in grads]
parent_block.append_op(
type='listen_and_serv',
inputs={},
inputs={"X": self.inputs},
outputs={},
attrs={
'endpoint': self.endpoint,
'Fanin': self.fan_in,
'ParamList': param_names,
'GradList': grad_names,
'OptimizeBlock': current_block
})
......@@ -196,10 +191,14 @@ def Send(endpoints, send_vars, get_vars):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars},
outputs={"Out": get_vars,
"RPCClient": rpc_client_var},
attrs={"endpoints": endpoints,
"epmap": epmap})
......
......@@ -38,14 +38,15 @@ class TestRecvOp(unittest.TestCase):
def init_serv(self, place):
main = fluid.Program()
with fluid.program_guard(main):
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
serv = layers.ListenAndServ("127.0.0.1:6174", optimizer_mode=False)
serv = layers.ListenAndServ(
"127.0.0.1:6174", ["X"], optimizer_mode=False)
with serv.do():
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
o = layers.scale(x=x, scale=10.0)
main.global_block().create_var(
name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册