diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index c229d82dd0410bf9aa8ac8efcf164ed871f33d85..cd66a330ee1568468f9831e083aaa570c8caa406 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -528,7 +528,7 @@ class TheOnePSRuntime(RuntimeBase): split_dense_table=self.role_maker._is_heter_parameter_server_mode) send_ctx = self.compiled_strategy.get_the_one_send_context( split_dense_table=self.role_maker._is_heter_parameter_server_mode, - use_origin_program=True, + use_origin_program=self.role_maker._is_heter_parameter_server_mode, ep_list=endpoints) trainer_config = self.async_strategy.get_trainer_runtime_config() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 7df175d79b75e4368f7fccdbec6079306987cd37..59d26f4837534b45dd986b0f2a6cd10e3f552ff2 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -105,8 +105,9 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): if op.type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0] - # trick for matchnet, need to modify - param_name += op.input("Ids")[0][0] + if config.is_heter_ps_mode: + # trick for matchnet, need to modify + param_name += op.input("Ids")[0][0] ops = pull_sparse_ops.get(param_name, []) ops.append(op) pull_sparse_ops[param_name] = ops