From a8959162749257cb52449a8effda19bd0c191205 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Thu, 14 Jun 2018 11:33:39 +0800
Subject: [PATCH] [wip] add load lookup table in io and trainer

---
 python/paddle/fluid/io.py      | 13 ++++++++++-
 python/paddle/fluid/trainer.py | 42 ++++++++++++++++------------------
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 6323c9899e0..0fb88de0bbb 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -25,7 +25,8 @@ __all__ = [
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
     'clean_checkpoint', 'load_persist_vars_without_grad',
-    'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
+    'load_lookup_table_vars', 'save_persist_vars_without_grad',
+    'get_latest_checkpoint_serial'
 ]
 
 
@@ -459,7 +460,9 @@ def get_parameter_value_by_name(name, executor, program=None):
 SUCCESS_MARK_FILENAME = "_SUCCESS"
 CHECKPOINT_PREFIX = "checkpoint"
 MODEL_DIR = "__model__"
+LOOKUP_TABLE_DIR = "__lookup_table__"
 TRAINER_PREFIX = "trainer"
+PSERVER_PREFIX = "pserver"
 CHECKPOINT_SEPARATOR = "_"
 
 
@@ -567,6 +570,14 @@ def load_persist_vars_without_grad(executor,
         filename=None)
 
 
+def load_lookup_table_vars(executor, dirname, pserver_id, table_name):
+    lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
+    table_file = table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX + CHECKPOINT_SEPARATOR + str(
+        pserver_id)
+
+    load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file)
+
+
 def save_persist_vars_without_grad(executor, dirname, program):
     """
     save_persist_vars_without_grad will save variables to a directory by an executor,
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index efc28d89930..2cb908f799b 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -62,27 +62,20 @@ class CheckpointConfig(object):
                  max_num_checkpoints=3,
                  epoch_interval=1,
                  step_interval=10):
-        if checkpoint_dir is None:
-            self.checkpoint_dir = os.getcwd()
-        else:
-            self.checkpoint_dir = checkpoint_dir
-
-        self.max_num_checkpoints = max_num_checkpoints
-
-        if epoch_interval < 1:
-            self.epoch_interval = 1
-        else:
-            self.epoch_interval = epoch_interval
-        if step_interval < 1:
-            self.step_interval = 10
-        else:
-            self.step_interval = step_interval
+        assert epoch_interval >= 1
+        assert step_interval >= 1
+        self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd(
+        )
+        self.max_num_checkpoints = max_num_checkpoints
+        self.epoch_interval = epoch_interval
+        self.step_interval = step_interval
 
         self.epoch_id = 0
         self.step_id = 0
         self.load_serial = None
         self.is_pserver = False
+        self.has_lookup_table = False
 
 
 def check_and_get_place(place):
@@ -181,13 +174,18 @@ class Trainer(object):
                     self.checkpoint_cfg.load_serial,
                     self.startup_program)
 
-                if not self.checkpoint_cfg.is_pserver:
-                    epoch_id, step_id = io.load_trainer_args(
-                        self.checkpoint_cfg.checkpoint_dir,
-                        self.checkpoint_cfg.load_serial, self.trainer_id,
-                        self._get_checkpoint_load_args())
-                    self.checkpoint_cfg.epoch_id = int(epoch_id)
-                    self.checkpoint_cfg.step_id = int(step_id)
+            if not self.checkpoint_cfg.is_pserver:
+                epoch_id, step_id = io.load_trainer_args(
+                    self.checkpoint_cfg.checkpoint_dir,
+                    self.checkpoint_cfg.load_serial, self.trainer_id,
+                    self._get_checkpoint_load_args())
+                self.checkpoint_cfg.epoch_id = int(epoch_id)
+                self.checkpoint_cfg.step_id = int(step_id)
+            else:
+                if self.checkpoint_cfg.has_lookup_table:
+                    io.load_lookup_table_vars(
+                        exe, self.checkpoint_cfg.checkpoint_dir, 0,
+                        "table_name")
 
         if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
-- 
GitLab
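
Note: a quick sketch of the on-disk convention this patch introduces. Each
pserver saves its shard of a distributed lookup table under
__lookup_table__/ inside the checkpoint directory, in a file named
<table_name>_pserver_<pserver_id>. The snippet below only mirrors the
filename logic of the new load_lookup_table_vars; the checkpoint directory
and table name ("/tmp/ckpt", "emb_table") are made-up values for
illustration, not anything taken from the patch.

    import os

    # Constants as added to python/paddle/fluid/io.py by this patch.
    LOOKUP_TABLE_DIR = "__lookup_table__"
    PSERVER_PREFIX = "pserver"
    CHECKPOINT_SEPARATOR = "_"

    def lookup_table_file(dirname, pserver_id, table_name):
        # The shard that pserver `pserver_id` holds for `table_name` is
        # expected at:
        #   <dirname>/__lookup_table__/<table_name>_pserver_<pserver_id>
        table_file = (table_name + CHECKPOINT_SEPARATOR + PSERVER_PREFIX +
                      CHECKPOINT_SEPARATOR + str(pserver_id))
        return os.path.join(dirname, LOOKUP_TABLE_DIR, table_file)

    print(lookup_table_file("/tmp/ckpt", 0, "emb_table"))
    # -> /tmp/ckpt/__lookup_table__/emb_table_pserver_0

Two loose ends, consistent with the [wip] marker: fluid.io.load_vars
ordinarily takes a list of Variable objects for vars=, not a bare name
string as load_lookup_table_vars passes here, and the trainer currently
hard-codes the placeholders 0 and "table_name" for the pserver id and
table name; both call sites presumably still need to be wired to real
values.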