diff --git a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py index 0e4876fc6504d0c19a8ced480c60f068feb69da0..b65435497284d279ebdea026e7ac88883a724c7c 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py @@ -65,6 +65,7 @@ class AsyncMetaOptimizer(MetaOptimizerBase): # for startup program _startup = worker.fake_init_ops_pass(_startup, compiled_config) + _startup = worker.init_from_server_pass(_startup, compiled_config) _startup = worker.delet_extra_optimizes_pass(_startup, compiled_config) else: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index e2d0f675216d92f9d33271d49ad9d225b5ba17c0..d2c7397c85f8df155444d9272c7b75596f0fe169 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -771,6 +771,7 @@ class ParameterServerOptimizer(DistributedOptimizer): # for startup program _startup = worker.fake_init_ops_pass(_startup, compiled_config) + _startup = worker.init_from_server_pass(_startup, compiled_config) _startup = worker.delet_extra_optimizes_pass(_startup, compiled_config) else: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 912eee0df0a6f9821066dc5c0285ea27c7e52874..fe483bddd6a482a34431e17fee354f6a8f5d80b1 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -212,22 +212,22 @@ def append_send_ops_pass(program, config): def init_from_server_pass(program, config): fetch_barrier_out = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - - recv_ctx = config.get_communicator_recv_context(recv_type=1) - recv_varnames = [] - - for name, ctxs in recv_ctx.items(): - recv_varnames.extend(ctxs.origin_varnames()) - - program.global_block().append_op( - type="recv", - inputs={"X": []}, - outputs={"Out": []}, - attrs={ - "recv_varnames": recv_varnames, - "trainer_id": config.get_role_id(), - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) + # + # recv_ctx = config.get_communicator_recv_context(recv_type=1) + # recv_varnames = [] + # + # for name, ctxs in recv_ctx.items(): + # recv_varnames.extend(ctxs.origin_varnames()) + # + # program.global_block().append_op( + # type="recv", + # inputs={"X": []}, + # outputs={"Out": []}, + # attrs={ + # "recv_varnames": recv_varnames, + # "trainer_id": config.get_role_id(), + # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + # }) program.global_block().append_op( type="fetch_barrier",