From 0deb6f90baa5dab02b5ff1cbc98dcaf7fae9b80b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 30 May 2018 14:20:51 +0800 Subject: [PATCH] annotation optimized and code style optimized --- python/paddle/fluid/io.py | 22 +++++++++++++++++++++- python/paddle/fluid/trainer.py | 12 ++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 2925e8eb2..d52c9a882 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -478,9 +478,10 @@ def save_checkpoint(executor, :param executor :param checkpoint_dir + :param trainer_id + :param is_chief :param main_program :param max_num_checkpoints - :param is_chief """ if checkpoint_dir is None: raise ValueError("The values of 'checkpoint_dir' should not be None") @@ -502,6 +503,11 @@ def save_checkpoint(executor, def need_load_checkpoint(checkpoint_dir): + """ + If the directory have checkpoint files, it will return lastest checkpoint directory serial number + + :param checkpoint_dir + """ serial = _get_lastest_checkpoint_dir(checkpoint_dir) if serial < 0: return None @@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): :param executor :param checkpoint_dir + :param serial :param main_program """ @@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): """ clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. delete_dir only works when the directory is empty, otherwise, OSError is raised. + + :param checkpoint_dir + :param delete_dir """ + if checkpoint_dir is None: raise ValueError("The values of 'checkpoint_dir' should not be None") _lru_delete(checkpoint_dir, max_num_checkpoints=0) @@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True): """ load_persist_vars_without_grad will load variables from a directory by an executor, the variable named end with "@GRAD" will not be loaded. + + :param executor + :param dirname + :param program + :param nest """ if nest: @@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program): """ save_persist_vars_without_grad will save variables to a directory by an executor, the variable named end with "@GRAD" will not be saved. + + :param executor + :param dirname + :param program """ cur_dir = _get_model_dir(dirname) save_vars( diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 34db9b39b..6d8d4a3e4 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -79,8 +79,8 @@ class CheckpointConfig(object): else: self.step_interval = step_interval - self._epoch_id = 0 - self._step_id = 0 + self.epoch_id = 0 + self.step_id = 0 self._load_serial = None @@ -185,8 +185,8 @@ class Trainer(object): epoch_id, step_id = io.load_trainer_args( self.checkpoint.checkpoint_dir, self.checkpoint._load_serial, self.trainer_id, ["epoch_id", "step_id"]) - self.checkpoint._epoch_id = int(epoch_id) - self.checkpoint._step_id = int(step_id) + self.checkpoint.epoch_id = int(epoch_id) + self.checkpoint.step_id = int(step_id) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -353,7 +353,7 @@ class Trainer(object): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): epochs = [ epoch_id for epoch_id in range(num_epochs) - if epoch_id >= self.checkpoint._epoch_id + if epoch_id >= self.checkpoint.epoch_id ] for epoch_id in epochs: event_handler(BeginEpochEvent(epoch_id)) @@ -363,7 +363,7 @@ class Trainer(object): return if self.checkpoint and self.checkpoint._load_serial \ - and self.checkpoint._step_id >= step_id and self.checkpoint._epoch_id == epoch_id: + and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: continue begin_event = BeginStepEvent(epoch_id, step_id) -- GitLab