提交 57a1d183 编写于 作者: T tangwei12

hidden slice_vars in distribute transpile, hidden it to users

上级 3693394c
...@@ -895,14 +895,14 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -895,14 +895,14 @@ def get_test_program(filelist, program=None, startup_program=None):
return program return program
def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts): def _load_slice_up_vars(executor, dirname, slice_vars_and_atts):
if slice_vars == None or len(slice_vars) == 0: if slice_vars_and_atts == None or len(slice_vars_and_atts) == 0:
return return
load_prog = Program() load_prog = Program()
load_block = load_prog.global_block() load_block = load_prog.global_block()
for var_tuple in slice_vars: for var_tuple in slice_vars_and_atts:
orig_var = var_tuple[0] orig_var = var_tuple[0]
start = var_tuple[1] start = var_tuple[1]
slice_var = var_tuple[2] slice_var = var_tuple[2]
......
...@@ -235,6 +235,9 @@ class Trainer(object): ...@@ -235,6 +235,9 @@ class Trainer(object):
# config for checkpoint # config for checkpoint
# only chief worker will save variables # only chief worker will save variables
self.trainer_id = 0 self.trainer_id = 0
self.pserver_id = None
self.pserver_endpoints = None
self.lookup_table_name = None
self.checkpoint_cfg = checkpoint_config self.checkpoint_cfg = checkpoint_config
if self.checkpoint_cfg: if self.checkpoint_cfg:
assert isinstance(self.checkpoint_cfg, CheckpointConfig) assert isinstance(self.checkpoint_cfg, CheckpointConfig)
...@@ -282,10 +285,12 @@ class Trainer(object): ...@@ -282,10 +285,12 @@ class Trainer(object):
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
io.load_persistables( _load_persistable_vars(exe, param_path, self.startup_program, False,
executor=exe, [self.lookup_table_name]
dirname=param_path, if self.lookup_table_name else [])
main_program=self.startup_program) if self.lookup_table_name and self.pserver_id:
_load_lookup_table_vars(exe, param_path, self.startup_program,
self.pserver_id, self.lookup_table_name)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册