未验证 提交 7a687876 编写于 作者: D Darcy 提交者: GitHub

Merge pull request #7667 from abhinavarora/dist_transpiler

Remove optimize_op argument from get_pserver_program method of Distributed Transpiler
......@@ -407,7 +407,7 @@ class DistributeTranspiler:
outputs=opt_op.outputs,
attrs=opt_op.attrs)
def get_pserver_program(self, endpoint, optimize_ops):
def get_pserver_program(self, endpoint):
"""
get pserver side program by endpoint
......@@ -422,9 +422,9 @@ class DistributeTranspiler:
self._clone_var(pserver_program.global_block(), v)
# step6
optimize_sub_program = Program()
for idx, opt_op in enumerate(optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
idx)
for idx, opt_op in enumerate(self.optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
if opt_op.inputs.has_key("Grad"):
......
......@@ -53,7 +53,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
else:
......
......@@ -197,7 +197,7 @@ def main():
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
......
......@@ -87,7 +87,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
......
......@@ -66,7 +66,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
......
......@@ -60,7 +60,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
......
......@@ -98,7 +98,7 @@ def main():
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册