From 8c1e1ded7e94e70ba7929db1472155b7b74e82e2 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Thu, 17 Oct 2019 13:16:24 +0800
Subject: [PATCH] fix fetch handler error with pslib (#20681)

* fix fetch handler error with pslib

* fix distributed lookup table op with 1 pserver
---
 python/paddle/fluid/executor.py               | 17 ++++-------------
 .../fluid/transpiler/distribute_transpiler.py |  9 ++++-----
 2 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index d61f9d9ad0d..4209db5a7a9 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -998,18 +998,6 @@ class Executor(object):
 
         if fetch_handler is not None:
             fetch_instance = fetch_handler
-        elif fetch_handler is None and fetch_list is not None:
-
-            class FH(FetchHandler):
-                def handler(self, fetch_target_vars):
-                    for i in range(len(fetch_target_vars)):
-                        print("{}: \n {}\n".format(fetch_info[i],
-                                                   fetch_target_vars[i]))
-
-            fetch_target_names = [var.name for var in fetch_list]
-            fetch_instance = FH(fetch_target_names,
-                                period_secs=print_period,
-                                return_np=False)
         else:
             fetch_instance = FetchHandler([])
 
@@ -1018,7 +1006,10 @@ class Executor(object):
             dataset=dataset,
             scope=scope,
             thread=thread,
-            debug=debug)
+            debug=debug,
+            fetch_list=fetch_list,
+            fetch_info=fetch_info,
+            print_period=print_period)
         trainer._set_infer(is_infer)
         trainer._gen_trainer_desc()
 
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index c1eff2f5f7e..b608543b3e8 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -793,6 +793,8 @@ class DistributeTranspiler(object):
             if self.sync_mode:
                 fetch_barrier_input.extend(splited_var)
 
+        self._update_remote_sparse_update_op(program, need_sparse_update_params)
+
         if self.sync_mode:
             # form a WAW dependency
             program.global_block().append_op(
@@ -806,11 +808,10 @@ class DistributeTranspiler(object):
                 })
 
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
-            if len(splited_var) <= 1:
-                continue
             orig_param = program.global_block().vars[param_varname]
             if param_varname not in self.sparse_param_to_height_sections:
-                if not self.config.runtime_split_send_recv:
+                if len(splited_var
+                       ) > 1 and not self.config.runtime_split_send_recv:
                     program.global_block().append_op(
                         type="concat",
                         inputs={"X": splited_var},
@@ -820,8 +821,6 @@ class DistributeTranspiler(object):
                             RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                         })
 
-        self._update_remote_sparse_update_op(program,
-                                             need_sparse_update_params)
         if not self.sync_mode:
             lr_ops = self._get_lr_ops()
             if len(lr_ops) > 0:
--
GitLab
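
Usage note (not part of the patch): after the executor.py change, train_from_dataset
forwards fetch_list, fetch_info, and print_period to the trainer instead of wrapping
them in an ad-hoc FetchHandler subclass inside the executor, which is what made
periodic fetch printing fail with pslib-based trainers. Below is a minimal sketch of
the call path this fixes, assuming the fluid 1.x API of this era; the toy network,
batch size, and the data file "train.txt" are illustrative, not taken from the patch.

import paddle.fluid as fluid

# Toy network: one-variable linear regression.
x = fluid.layers.data(name="x", shape=[1], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
y_pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=y_pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Dataset-based feeding; "train.txt" is a placeholder data file.
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_use_var([x, y])
dataset.set_batch_size(32)
dataset.set_filelist(["train.txt"])
dataset.load_into_memory()

# With this patch, fetch_list/fetch_info/print_period reach the trainer
# directly; the executor no longer builds its own FetchHandler for them.
exe.train_from_dataset(
    program=fluid.default_main_program(),
    dataset=dataset,
    fetch_list=[avg_cost],
    fetch_info=["avg_cost"],
    print_period=10)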
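
The transpiler half of the fix matters when a parameter ends up with a single split,
e.g. with exactly one pserver: the old loop hit "if len(splited_var) <= 1: continue"
and never reached _update_remote_sparse_update_op, so distributed lookup-table
updates were left unwired. A hedged sketch of a transpile that exercises this path
follows; the endpoint, vocabulary size, and network are illustrative assumptions.

import paddle.fluid as fluid

# Sparse embedding, so the parameter takes the sparse-update path.
word = fluid.layers.data(name="word", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
emb = fluid.layers.embedding(input=word, size=[100000, 64], is_sparse=True)
pooled = fluid.layers.sequence_pool(input=emb, pool_type="sum")
fc = fluid.layers.fc(input=pooled, size=2)
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=fluid.layers.softmax(fc), label=label))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# A single pserver endpoint means each parameter has exactly one split,
# which is the len(splited_var) <= 1 case the old code skipped outright.
config = fluid.DistributeTranspilerConfig()
t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174",
    trainers=1,
    sync_mode=True)
trainer_prog = t.get_trainer_program()
pserver_prog = t.get_pserver_program("127.0.0.1:6174")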