Unverified commit 1d925440, authored by tangwei12, committed by GitHub

fix fetch handler error with pslib (#20679)

* fix fetch handler error with pslib
* fix distributed lookup table op with 1 pserver
Parent 78431dc7
@@ -998,18 +998,6 @@ class Executor(object):
         if fetch_handler is not None:
             fetch_instance = fetch_handler
-        elif fetch_handler is None and fetch_list is not None:
-
-            class FH(FetchHandler):
-                def handler(self, fetch_target_vars):
-                    for i in range(len(fetch_target_vars)):
-                        print("{}: \n {}\n".format(fetch_info[i],
-                                                   fetch_target_vars[i]))
-
-            fetch_target_names = [var.name for var in fetch_list]
-            fetch_instance = FH(fetch_target_names,
-                                period_secs=print_period,
-                                return_np=False)
         else:
             fetch_instance = FetchHandler([])
@@ -1018,7 +1006,10 @@ class Executor(object):
             dataset=dataset,
             scope=scope,
             thread=thread,
-            debug=debug)
+            debug=debug,
+            fetch_list=fetch_list,
+            fetch_info=fetch_info,
+            print_period=print_period)
         trainer._set_infer(is_infer)
         trainer._gen_trainer_desc()
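Note: below is a minimal sketch of the two ways a caller can get periodic fetch output from `Executor.train_from_dataset` after this change, either by passing a `FetchHandler` subclass directly or by passing plain `fetch_list`/`fetch_info`/`print_period`, which the executor now forwards to the trainer instead of wrapping them in an ad-hoc handler. The names `main_program`, `dataset`, and `loss` are assumptions, not taken from this commit; the handler API used (`handler(self, fetch_target_vars)`, `period_secs`, `return_np`) mirrors the code removed above.

```python
# Sketch only: main_program, dataset and loss are assumed to be built
# elsewhere; the FetchHandler API used here (period_secs, return_np,
# handler()) mirrors the code removed from Executor in the diff above.
import paddle.fluid as fluid
from paddle.fluid.executor import FetchHandler


class LossPrinter(FetchHandler):
    def handler(self, fetch_target_vars):
        # Called roughly every period_secs with the current values of the
        # requested variables, in the same order as the names passed in.
        for i, value in enumerate(fetch_target_vars):
            print("target {}:\n{}\n".format(i, value))


exe = fluid.Executor(fluid.CPUPlace())

# Path 1: an explicit handler object is used by the executor as-is.
# exe.train_from_dataset(program=main_program, dataset=dataset,
#                        fetch_handler=LossPrinter([loss.name], period_secs=60))

# Path 2: plain fetch arguments; after this patch they are forwarded to the
# trainer (fetch_list/fetch_info/print_period) instead of being wrapped in an
# ad-hoc FetchHandler inside Executor.
# exe.train_from_dataset(program=main_program, dataset=dataset,
#                        fetch_list=[loss], fetch_info=["loss"],
#                        print_period=100)
```

The design change is visible in the second hunk above: the fetch arguments now travel with the trainer description instead of being consumed inside `Executor`.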
@@ -793,6 +793,8 @@ class DistributeTranspiler(object):
             if self.sync_mode:
                 fetch_barrier_input.extend(splited_var)
 
+        self._update_remote_sparse_update_op(program, need_sparse_update_params)
+
         if self.sync_mode:
             # form a WAW dependency
             program.global_block().append_op(
@@ -806,11 +808,10 @@ class DistributeTranspiler(object):
                 })
 
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
-            if len(splited_var) <= 1:
-                continue
             orig_param = program.global_block().vars[param_varname]
             if param_varname not in self.sparse_param_to_height_sections:
-                if not self.config.runtime_split_send_recv:
+                if len(splited_var
+                       ) > 1 and not self.config.runtime_split_send_recv:
                     program.global_block().append_op(
                         type="concat",
                         inputs={"X": splited_var},
@@ -820,8 +821,6 @@ class DistributeTranspiler(object):
                             RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                         })
 
-        self._update_remote_sparse_update_op(program,
-                                             need_sparse_update_params)
         if not self.sync_mode:
             lr_ops = self._get_lr_ops()
             if len(lr_ops) > 0:
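Note on the second fix: the early `continue` for single-split parameters is removed, the length check is folded into the concat condition, and `_update_remote_sparse_update_op` now runs before the sync barrier rather than after the concat loop, so sparse (distributed lookup table) parameters that live on a single pserver are no longer skipped. A rough sketch of that setup follows; it uses the fluid 1.x transpiler API, the pserver address is made up, and the trainer network/optimizer are assumed to already exist in the default program.

```python
# Sketch only (fluid 1.x transpiler API): the trainer network and optimizer
# are assumed to already exist in the default program, and the pserver
# address below is made up for illustration.
import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()  # left at defaults here

t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174",  # a single pserver endpoint
    trainers=1,
    sync_mode=True)

# With one pserver, a sparse parameter (e.g. an embedding used as a
# distributed lookup table) is transpiled into exactly one split, which is
# the case the removed `len(splited_var) <= 1` early-continue used to skip.
trainer_program = t.get_trainer_program()
pserver_program = t.get_pserver_program("127.0.0.1:6174")
```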