From 8a178165a6ae4d19f226e2e13d291d96be61798e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 14 Jun 2018 20:33:51 +0800 Subject: [PATCH] add lookuo table in python --- python/paddle/fluid/io.py | 30 +++++++++++++++++++++++++++++- python/paddle/fluid/trainer.py | 3 ++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 0fb88de0bbb..6a0e422cb37 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -500,6 +500,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) + save_pserver_vars_by_notify(executor, cur_dir, "") _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -530,7 +531,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): def clean_checkpoint(checkpoint_dir, delete_dir=False): """ - clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. + clean the checkpoint dir, when the train exits normally, + the trainer will call clean_checkpoint to delete checkpoint directory saved before. delete_dir only works when the directory is empty, otherwise, OSError is raised. :param checkpoint_dir @@ -598,6 +600,23 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) +def save_pserver_vars_by_notify(executor, dirname, epmap): + """ + """ + cur_dir = _get_lookuptable_dir(dirname) + + checkpoint_notify_program = Program() + checkpoint_notify_block = checkpoint_notify_program.global_block() + + attrs = {} + attrs['epmap'] = None + attrs['dir'] = cur_dir + + checkpoint_notify_block.append_op( + type='checkpointnotify', inputs={}, output={}, attrs=attrs) + executor.run(checkpoint_notify_program) + + def save_trainer_args(dirname, trainer_id, trainer_args): assert isinstance(trainer_args, dict) @@ -680,6 +699,15 @@ def _get_model_dir(dirname): return model_dir +def _get_lookuptable_dir(dirname): + lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + + if not os.path.isdir(lookuptable_dir): + os.makedirs(lookuptable_dir) + + return lookuptable_dir + + def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 2cb908f799b..f77c0f65dcb 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -446,7 +446,8 @@ class Trainer(object): def _save_checkpoint(self, epoch_id, step_id): assert self.checkpoint_cfg - if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0: + if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ + and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) io.save_checkpoint( executor=exe, -- GitLab