未验证 提交 201928d0 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #11839 from velconia/fix_reader_op_in_pserver

Do NOT clone input vars if op NOT in pserver_program
...@@ -309,10 +309,10 @@ class DistributeTranspiler(object): ...@@ -309,10 +309,10 @@ class DistributeTranspiler(object):
def get_pserver_program(self, endpoint): def get_pserver_program(self, endpoint):
""" """
Get parameter server side program. Get parameter server side program.
Args: Args:
endpoint (str): current parameter server endpoint. endpoint (str): current parameter server endpoint.
Returns: Returns:
Program: the program for current parameter server to run. Program: the program for current parameter server to run.
""" """
...@@ -516,7 +516,7 @@ class DistributeTranspiler(object): ...@@ -516,7 +516,7 @@ class DistributeTranspiler(object):
endpoint (str): current pserver endpoint. endpoint (str): current pserver endpoint.
pserver_program (Program): call get_pserver_program first and pserver_program (Program): call get_pserver_program first and
pass the result here. pass the result here.
Returns: Returns:
Program: parameter server side startup program. Program: parameter server side startup program.
""" """
...@@ -552,10 +552,10 @@ class DistributeTranspiler(object): ...@@ -552,10 +552,10 @@ class DistributeTranspiler(object):
op_on_pserver = True op_on_pserver = True
new_outputs[key] = pserver_vars[op.output(key)[0]] new_outputs[key] = pserver_vars[op.output(key)[0]]
# most startup program ops have no inputs
new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op_on_pserver: if op_on_pserver:
# most startup program ops have no inputs
new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op.type in [ if op.type in [
"gaussian_random", "fill_constant", "uniform_random" "gaussian_random", "fill_constant", "uniform_random"
]: ]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册