diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 5183f1d8b86f22762535d2aa0c1bcbdfd92bbf1d..e5eb34eb0f5cfceeaa7ac09d5bd3d9dc60b7692f 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -895,14 +895,14 @@ def get_test_program(filelist, program=None, startup_program=None): return program -def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts): - if slice_vars == None or len(slice_vars) == 0: +def _load_slice_up_vars(executor, dirname, slice_vars_and_atts): + if slice_vars_and_atts == None or len(slice_vars_and_atts) == 0: return load_prog = Program() load_block = load_prog.global_block() - for var_tuple in slice_vars: + for var_tuple in slice_vars_and_atts: orig_var = var_tuple[0] start = var_tuple[1] slice_var = var_tuple[2] diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 8194a4e331c2a3fda13dc515d2129e43b7723d52..68c62a0e0bf09312647f28bd12a47971698c4c5f 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -235,6 +235,9 @@ class Trainer(object): # config for checkpoint # only chief worker will save variables self.trainer_id = 0 + self.pserver_id = None + self.pserver_endpoints = None + self.lookup_table_name = None self.checkpoint_cfg = checkpoint_config if self.checkpoint_cfg: assert isinstance(self.checkpoint_cfg, CheckpointConfig) @@ -282,10 +285,12 @@ class Trainer(object): if param_path and os.path.isdir(param_path): # load params from param_path into scope - io.load_persistables( - executor=exe, - dirname=param_path, - main_program=self.startup_program) + _load_persistable_vars(exe, param_path, self.startup_program, False, + [self.lookup_table_name] + if self.lookup_table_name else []) + if self.lookup_table_name and self.pserver_id: + _load_lookup_table_vars(exe, param_path, self.startup_program, + self.pserver_id, self.lookup_table_name) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS