From 57a1d183447fe2231a691ecd0f975d2d2971b41e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 25 Jul 2018 14:52:15 +0800 Subject: [PATCH] hidden slice_vars in distribute transpile, hidden it to users --- python/paddle/fluid/io.py | 6 +++--- python/paddle/fluid/trainer.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 5183f1d8b8..e5eb34eb0f 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -895,14 +895,14 @@ def get_test_program(filelist, program=None, startup_program=None): return program -def _load_slice_up_vars(executor, dirname, _slice_vars_and_atts): - if slice_vars == None or len(slice_vars) == 0: +def _load_slice_up_vars(executor, dirname, slice_vars_and_atts): + if slice_vars_and_atts == None or len(slice_vars_and_atts) == 0: return load_prog = Program() load_block = load_prog.global_block() - for var_tuple in slice_vars: + for var_tuple in slice_vars_and_atts: orig_var = var_tuple[0] start = var_tuple[1] slice_var = var_tuple[2] diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 8194a4e331..68c62a0e0b 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -235,6 +235,9 @@ class Trainer(object): # config for checkpoint # only chief worker will save variables self.trainer_id = 0 + self.pserver_id = None + self.pserver_endpoints = None + self.lookup_table_name = None self.checkpoint_cfg = checkpoint_config if self.checkpoint_cfg: assert isinstance(self.checkpoint_cfg, CheckpointConfig) @@ -282,10 +285,12 @@ class Trainer(object): if param_path and os.path.isdir(param_path): # load params from param_path into scope - io.load_persistables( - executor=exe, - dirname=param_path, - main_program=self.startup_program) + _load_persistable_vars(exe, param_path, self.startup_program, False, + [self.lookup_table_name] + if self.lookup_table_name else []) + if self.lookup_table_name and self.pserver_id: + _load_lookup_table_vars(exe, param_path, self.startup_program, + self.pserver_id, self.lookup_table_name) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS -- GitLab