未验证 提交 7a687876 编写于 作者: D Darcy 提交者: GitHub

Merge pull request #7667 from abhinavarora/dist_transpiler

Remove optimize_op argument from get_pserver_program method of Distributed Transpiler
...@@ -407,7 +407,7 @@ class DistributeTranspiler: ...@@ -407,7 +407,7 @@ class DistributeTranspiler:
outputs=opt_op.outputs, outputs=opt_op.outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
def get_pserver_program(self, endpoint, optimize_ops): def get_pserver_program(self, endpoint):
""" """
get pserver side program by endpoint get pserver side program by endpoint
...@@ -422,9 +422,9 @@ class DistributeTranspiler: ...@@ -422,9 +422,9 @@ class DistributeTranspiler:
self._clone_var(pserver_program.global_block(), v) self._clone_var(pserver_program.global_block(), v)
# step6 # step6
optimize_sub_program = Program() optimize_sub_program = Program()
for idx, opt_op in enumerate(optimize_ops): for idx, opt_op in enumerate(self.optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops, is_op_on_pserver = self._is_op_on_pserver(endpoint,
idx) self.optimize_ops, idx)
if not is_op_on_pserver: if not is_op_on_pserver:
continue continue
if opt_op.inputs.has_key("Grad"): if opt_op.inputs.has_key("Grad"):
......
...@@ -53,7 +53,7 @@ if training_role == "PSERVER": ...@@ -53,7 +53,7 @@ if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exe.run(pserver_prog) exe.run(pserver_prog)
else: else:
......
...@@ -197,7 +197,7 @@ def main(): ...@@ -197,7 +197,7 @@ def main():
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
......
...@@ -87,7 +87,7 @@ if training_role == "PSERVER": ...@@ -87,7 +87,7 @@ if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
......
...@@ -66,7 +66,7 @@ if training_role == "PSERVER": ...@@ -66,7 +66,7 @@ if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
......
...@@ -60,7 +60,7 @@ if training_role == "PSERVER": ...@@ -60,7 +60,7 @@ if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
......
...@@ -98,7 +98,7 @@ def main(): ...@@ -98,7 +98,7 @@ def main():
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册