From 9735f25011b04116d271861fde8df05def81c3ce Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 5 Jun 2018 14:47:13 +0800
Subject: [PATCH] optimized

---
 python/paddle/fluid/io.py      | 44 +++++++++++++---------------------
 python/paddle/fluid/trainer.py |  8 +++----
 2 files changed, 20 insertions(+), 32 deletions(-)

diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index b5d96441bcf..5abadc73f76 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -492,7 +492,7 @@ def save_checkpoint(executor,
     if not os.path.isdir(checkpoint_dir):
         os.makedirs(checkpoint_dir)

-    serial = _get_latest_checkpoint_dir(checkpoint_dir) + 1
+    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial)

     save_trainer_args(cur_dir, trainer_id, trainer_args)
@@ -503,18 +503,6 @@ def save_checkpoint(executor,
     _lru_delete(checkpoint_dir, max_num_checkpoints)


-def get_latest_checkpoint_serial(checkpoint_dir):
-    """
-    If the directory have checkpoint files, it will return latest checkpoint directory serial number
-
-    :param checkpoint_dir
-    """
-    serial = _get_latest_checkpoint_dir(checkpoint_dir)
-    if serial < 0:
-        return None
-    return serial
-
-
 def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     """
     Load checkpoint from a directory by executor,
@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     """

     if checkpoint_dir is None:
-        raise ValueError(
-            "The values of 'checkpoint_dir' or 'serial' should not be None")
+        raise ValueError("The values of 'checkpoint_dir' should not be None")

     if serial is None or serial < 0:
         raise ValueError("The values of 'serial' should not be None or <0 ")

     if main_program is None:
-        raise ValueError("The values of 'main_program'should not be None")
+        raise ValueError('main_program should not be None.')

     cur_dir = _get_serial_dir(checkpoint_dir, serial)
-    load_persist_vars_without_grad(executor, cur_dir, main_program)
+    load_persist_vars_without_grad(executor, cur_dir, main_program, True)


 def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)


-def load_persist_vars_without_grad(executor, dirname, program, nest=True):
+def load_persist_vars_without_grad(executor,
+                                   dirname,
+                                   program,
+                                   has_model_dir=False):
     """
     load_persist_vars_without_grad will load variables from a directory by an executor,
     the variable named end with "@GRAD" will not be loaded.

-    :param executor
-    :param dirname
-    :param program
-    :param nest
+    :param executor the executor used to load the variables
+    :param dirname the checkpoint directory to load from
+    :param program all persistable variables in this program will be loaded
+    :param has_model_dir if True, variables are loaded from the __model__ sub-directory of dirname
     """
-    if nest:
+    if has_model_dir:
         dirname = _get_model_dir(dirname)

     load_vars(
@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program):
     save_persist_vars_without_grad will save variables to a directory by an executor,
     the variable named end with "@GRAD" will not be saved.

-    :param executor
-    :param dirname
-    :param program
+    :param executor the executor used to save the variables
+    :param dirname the checkpoint directory to save into
+    :param program all persistable variables in this program will be saved
     """
     cur_dir = _get_model_dir(dirname)
     save_vars(
@@ -722,7 +712,7 @@ def _write_success(dirname):
         f.write(now)


-def _get_latest_checkpoint_dir(checkpoint_dir):
+def get_latest_checkpoint_serial(checkpoint_dir):
     """
     get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory

diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 3c32ec1de8a..fbdd28f53ef 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -146,8 +146,9 @@ class Trainer(object):
                 "The checkpoint_config shoule be an instance of CheckpointConfig"
             )
         else:
-            self.checkpoint.load_serial = io.get_latest_checkpoint_serial(
+            serial = io.get_latest_checkpoint_serial(
                 self.checkpoint.checkpoint_dir)
+            self.checkpoint.load_serial = serial if serial >= 0 else None

         self.scope = core.Scope()

@@ -194,10 +195,7 @@ class Trainer(object):
         if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
             io.load_persist_vars_without_grad(
-                exe,
-                dirname=param_path,
-                program=self.startup_program,
-                nest=False)
+                exe, dirname=param_path, program=self.startup_program)

     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
--
GitLab
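
For reference, a minimal caller-side sketch of the checkpoint API after this change. It uses only the signatures visible in this diff; the executor/program setup and the directory paths are illustrative assumptions, not code from the patch.

import paddle.fluid as fluid
import paddle.fluid.io as io

# Illustrative setup: any Fluid executor/program pair would do here.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.default_main_program()

checkpoint_dir = "/tmp/checkpoints"  # assumed path, not taken from the patch

# get_latest_checkpoint_serial() is now the public lookup (the former private
# _get_latest_checkpoint_dir); per the new Trainer guard it can return a
# negative serial when no checkpoint exists, so callers check the sign.
serial = io.get_latest_checkpoint_serial(checkpoint_dir)
if serial >= 0:
    # load_checkpoint() now forwards has_model_dir=True, so variables are read
    # from the __model__ sub-directory of the selected serial directory.
    io.load_checkpoint(exe, checkpoint_dir, serial, main_program)
else:
    # No checkpoint yet: load plain persistable variables from a parameter
    # directory instead (has_model_dir defaults to False, which is what the
    # updated Trainer call relies on).
    io.load_persist_vars_without_grad(
        exe, dirname="/tmp/pretrained_params", program=main_program)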