diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 6323c9899e0080b436a52f852c647466b8f94bc1..0fb88de0bbbf6524e11cc706a620e0acbc5f65b7 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -25,7 +25,8 @@ __all__ = [
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
     'clean_checkpoint', 'load_persist_vars_without_grad',
-    'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
+    'load_lookup_table_vars', 'save_persist_vars_without_grad',
+    'get_latest_checkpoint_serial'
 ]
 
 
@@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None):
 SUCCESS_MARK_FILENAME = "_SUCCESS"
 CHECKPOINT_PREFIX = "checkpoint"
 MODEL_DIR = "__model__"
+LOOKUP_TABLE_DIR = "__lookup_table__"
 TRAINER_PREFIX = "trainer"
+PSERVER_PREFIX = "pserver"
 CHECKPOINT_SEPARATOR = "_"
 
 
@@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor,
         filename=None)
 
 
+def load_lookup_table_vars(executor, dirname, pserver_id, table_name):
+    lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
+    table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str(
+        pserver_id)
+
+    load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file)
+
+
 def save_persist_vars_without_grad(executor, dirname, program):
     """
     save_persist_vars_without_grad will save variables to a directory by an executor,
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index efc28d899304b01a3085891f3ae9396d57c589a1..2cb908f799bf87146ed48760999125d625baed7e 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -62,27 +62,20 @@
                  max_num_checkpoints=3,
                  epoch_interval=1,
                  step_interval=10):
-        if checkpoint_dir is None:
-            self.checkpoint_dir = os.getcwd()
-        else:
-            self.checkpoint_dir = checkpoint_dir
-
-        self.max_num_checkpoints = max_num_checkpoints
-
-        if epoch_interval < 1:
-            self.epoch_interval = 1
-        else:
-            self.epoch_interval = epoch_interval
-        if step_interval < 1:
-            self.step_interval = 10
-        else:
-            self.step_interval = step_interval
+        assert epoch_interval >= 1
+        assert step_interval >= 1
+        self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd(
+        )
+        self.max_num_checkpoints = max_num_checkpoints
+        self.epoch_interval = epoch_interval
+        self.step_interval = step_interval
 
         self.epoch_id = 0
         self.step_id = 0
 
         self.load_serial = None
         self.is_pserver = False
+        self.has_lookup_table = False
 
 
 def check_and_get_place(place):
@@ -181,13 +174,18 @@
                                    self.checkpoint_cfg.load_serial,
                                    self.startup_program)
 
-            if not self.checkpoint_cfg.is_pserver:
-                epoch_id, step_id = io.load_trainer_args(
-                    self.checkpoint_cfg.checkpoint_dir,
-                    self.checkpoint_cfg.load_serial, self.trainer_id,
-                    self._get_checkpoint_load_args())
-                self.checkpoint_cfg.epoch_id = int(epoch_id)
-                self.checkpoint_cfg.step_id = int(step_id)
+            if not self.checkpoint_cfg.is_pserver:
+                epoch_id, step_id = io.load_trainer_args(
+                    self.checkpoint_cfg.checkpoint_dir,
+                    self.checkpoint_cfg.load_serial, self.trainer_id,
+                    self._get_checkpoint_load_args())
+                self.checkpoint_cfg.epoch_id = int(epoch_id)
+                self.checkpoint_cfg.step_id = int(step_id)
+            else:
+                if self.checkpoint_cfg.has_lookup_table:
+                    io.load_lookup_table_vars(
+                        exe, self.checkpoint_cfg.checkpoint_dir, 0,
+                        "table_name")
 
         if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
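
For orientation, a minimal usage sketch of the pieces this patch adds, assuming Fluid's Executor/CPUPlace APIs of the same era; the checkpoint directory and the table name "emb_table" are hypothetical placeholders, not values from the patch:

# Minimal sketch (assumptions: Fluid high-level Trainer API of this era;
# "/tmp/ckpt" and "emb_table" are illustrative, not taken from the diff).
import paddle.fluid as fluid
from paddle.fluid.trainer import CheckpointConfig

# epoch_interval and step_interval must now be >= 1: invalid values are
# no longer silently clamped; the new asserts fail fast instead.
config = CheckpointConfig(
    checkpoint_dir="/tmp/ckpt",  # falls back to os.getcwd() when None
    max_num_checkpoints=3,
    epoch_interval=1,
    step_interval=10)

# A pserver that hosts a distributed lookup table opts in to table restore;
# the trainer then routes through io.load_lookup_table_vars on load.
config.is_pserver = True
config.has_lookup_table = True

# Direct restore path: reads the file
#   <dirname>/__lookup_table__/<table_name>_pserver_<pserver_id>
exe = fluid.Executor(fluid.CPUPlace())
fluid.io.load_lookup_table_vars(exe, "/tmp/ckpt", 0, "emb_table")

Note that the trainer hunk still hard-codes pserver id 0 and the literal table name "table_name", so the values above only illustrate the intended call shape.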