From 61284c0aaf8b411976615069337d98ee7c72d078 Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Fri, 21 Aug 2020 19:31:15 +0800 Subject: [PATCH] optimize init from pserver --- .../fleet/meta_optimizers/async_optimizer.py | 1 + .../distribute_transpiler/__init__.py | 1 + .../fleet/parameter_server/ir/trainer_pass.py | 32 +++++++++---------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py index 0e4876fc650..b6543549728 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py @@ -65,6 +65,7 @@ class AsyncMetaOptimizer(MetaOptimizerBase): # for startup program _startup = worker.fake_init_ops_pass(_startup, compiled_config) + _startup = worker.init_from_server_pass(_startup, compiled_config) _startup = worker.delet_extra_optimizes_pass(_startup, compiled_config) else: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index e2d0f675216..d2c7397c85f 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -771,6 +771,7 @@ class ParameterServerOptimizer(DistributedOptimizer): # for startup program _startup = worker.fake_init_ops_pass(_startup, compiled_config) + _startup = worker.init_from_server_pass(_startup, compiled_config) _startup = worker.delet_extra_optimizes_pass(_startup, compiled_config) else: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 912eee0df0a..fe483bddd6a 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -212,22 +212,22 @@ def append_send_ops_pass(program, config): def init_from_server_pass(program, config): fetch_barrier_out = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - - recv_ctx = config.get_communicator_recv_context(recv_type=1) - recv_varnames = [] - - for name, ctxs in recv_ctx.items(): - recv_varnames.extend(ctxs.origin_varnames()) - - program.global_block().append_op( - type="recv", - inputs={"X": []}, - outputs={"Out": []}, - attrs={ - "recv_varnames": recv_varnames, - "trainer_id": config.get_role_id(), - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) + # + # recv_ctx = config.get_communicator_recv_context(recv_type=1) + # recv_varnames = [] + # + # for name, ctxs in recv_ctx.items(): + # recv_varnames.extend(ctxs.origin_varnames()) + # + # program.global_block().append_op( + # type="recv", + # inputs={"X": []}, + # outputs={"Out": []}, + # attrs={ + # "recv_varnames": recv_varnames, + # "trainer_id": config.get_role_id(), + # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + # }) program.global_block().append_op( type="fetch_barrier", -- GitLab