提交 61284c0a 编写于 作者: S seiriosPlus

optimize init from pserver

上级 2ee51619
...@@ -65,6 +65,7 @@ class AsyncMetaOptimizer(MetaOptimizerBase): ...@@ -65,6 +65,7 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
# for startup program # for startup program
_startup = worker.fake_init_ops_pass(_startup, compiled_config) _startup = worker.fake_init_ops_pass(_startup, compiled_config)
_startup = worker.init_from_server_pass(_startup, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup, _startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config) compiled_config)
else: else:
......
...@@ -771,6 +771,7 @@ class ParameterServerOptimizer(DistributedOptimizer): ...@@ -771,6 +771,7 @@ class ParameterServerOptimizer(DistributedOptimizer):
# for startup program # for startup program
_startup = worker.fake_init_ops_pass(_startup, compiled_config) _startup = worker.fake_init_ops_pass(_startup, compiled_config)
_startup = worker.init_from_server_pass(_startup, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup, _startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config) compiled_config)
else: else:
......
...@@ -212,22 +212,22 @@ def append_send_ops_pass(program, config): ...@@ -212,22 +212,22 @@ def append_send_ops_pass(program, config):
def init_from_server_pass(program, config): def init_from_server_pass(program, config):
fetch_barrier_out = program.global_block().create_var( fetch_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
#
recv_ctx = config.get_communicator_recv_context(recv_type=1) # recv_ctx = config.get_communicator_recv_context(recv_type=1)
recv_varnames = [] # recv_varnames = []
#
for name, ctxs in recv_ctx.items(): # for name, ctxs in recv_ctx.items():
recv_varnames.extend(ctxs.origin_varnames()) # recv_varnames.extend(ctxs.origin_varnames())
#
program.global_block().append_op( # program.global_block().append_op(
type="recv", # type="recv",
inputs={"X": []}, # inputs={"X": []},
outputs={"Out": []}, # outputs={"Out": []},
attrs={ # attrs={
"recv_varnames": recv_varnames, # "recv_varnames": recv_varnames,
"trainer_id": config.get_role_id(), # "trainer_id": config.get_role_id(),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) # })
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册