diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8fcc7787091b4d0ec3b6566be7fd826f3f95d7db..34c527b62f4eef8e2257628c911e11b79756d004 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -476,14 +476,14 @@ def save_checkpoint(executor, to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, The interval between two saved checkpoints must greater than save_interval_secs. - :param executor - :param checkpoint_dir - :param trainer_id - :param is_chief - :param main_program - :param max_num_checkpoints - """ - if checkpoint_dir.strip() is None: + :param executor executor for save the value + :param checkpoint_dir the checkpoint directory + :param trainer_id currect trainer id + :param is_chief if the trainer id equals 0, the is_chief will be true + :param main_program will save all variables in program + :param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints + """ + if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") if trainer_args: @@ -500,7 +500,7 @@ def save_checkpoint(executor, if is_chief: save_persist_vars_without_grad(executor, cur_dir, main_program) - _lru_delete(checkpoint_dir, max_num_checkpoints) + _scroll_delete(checkpoint_dir, max_num_checkpoints) def load_checkpoint(executor, checkpoint_dir, serial, main_program): @@ -508,13 +508,13 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): Load checkpoint from a directory by executor, it will find the most recent saved checkpoint file and load it auto. - :param executor - :param checkpoint_dir - :param serial - :param main_program + :param executor executor for load the value + :param checkpoint_dir the checkpoint directory + :param serial the serial folder in checkpoint directory will be load + :param main_program will load all variables in program """ - if checkpoint_dir.strip() is None: + if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") if serial is None or serial < 0: @@ -536,9 +536,9 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): :param delete_dir """ - if checkpoint_dir.strip() is None: + if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") - _lru_delete(checkpoint_dir, max_num_checkpoints=0) + _scroll_delete(checkpoint_dir, max_num_checkpoints=0) if delete_dir and not os.listdir(checkpoint_dir): os.rmdir(checkpoint_dir) @@ -681,7 +681,7 @@ def _get_trainer_dir(dirname, trainer_id): return trainer_dir -def _lru_delete(dirname, max_num_checkpoints=3): +def _scroll_delete(dirname, max_num_checkpoints=3): dirs = os.listdir(dirname) serial_map = {} for serial in dirs: @@ -717,7 +717,7 @@ def get_latest_checkpoint_serial(checkpoint_dir): :param checkpoint_dir """ - if not checkpoint_dir.strip(): + if not checkpoint_dir: return -1 def has_success(checkpoint_dir, cur_dir): @@ -726,10 +726,8 @@ def get_latest_checkpoint_serial(checkpoint_dir): """ serial = _get_dir_serial(cur_dir) - if serial == -1: - return -1 - - if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): + if serial == -1 or not os.path.isdir( + os.path.join(checkpoint_dir, cur_dir)): return -1 success_path = os.path.join( diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py index 150e8822d577be7380d826a473c92402317c0ad2..cf70dfd448363c85a257a1c32c68ad67343a15b6 100644 --- a/python/paddle/fluid/tests/unittests/test_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_checkpoint.py @@ -15,11 +15,12 @@ import paddle.fluid as fluid import unittest import os +import tempfile class TestCheckpoint(unittest.TestCase): def setUp(self): - self.dirname = "/tmp/ckpt" + self.dirname = tempfile.mktemp() self.max_num_checkpoints = 3 self.epoch_interval = 1 self.step_interval = 1 diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 9882d5cda04d2f28bca37a9b47445d586a76d57e..e5cec4c76af2f1a21c5b032b5c371d8c66edc90d 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -132,19 +132,18 @@ class Trainer(object): # 1. we need to generate a framework.Program by calling # program_func. Reference: fluid.program_guard in # test_word2vec.py - if not isinstance(optimizer, opt_module.Optimizer): - raise TypeError("The optimizer should be an instance of Optimizer") + assert isinstance(optimizer, opt_module.Optimizer) # config for checkpoint # only chief worker will save variables self.trainer_id = 0 self.chief = True - self.checkpoint = checkpoint_config - if self.checkpoint: - assert isinstance(self.checkpoint, CheckpointConfig) + self.checkpoint_cfg = checkpoint_config + if self.checkpoint_cfg: + assert isinstance(self.checkpoint_cfg, CheckpointConfig) serial = io.get_latest_checkpoint_serial( - self.checkpoint.checkpoint_dir) - self.checkpoint.load_serial = serial if serial >= 0 else None + self.checkpoint_cfg.checkpoint_dir) + self.checkpoint_cfg.load_serial = serial if serial >= 0 else None self.scope = core.Scope() @@ -174,19 +173,20 @@ class Trainer(object): exe = executor.Executor(place) exe.run(self.startup_program) - if self.checkpoint and self.checkpoint.load_serial: + if self.checkpoint_cfg and self.checkpoint_cfg.load_serial: with self._prog_and_scope_guard(): exe = executor.Executor(place) - io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, - self.checkpoint.load_serial, + io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir, + self.checkpoint_cfg.load_serial, self.startup_program) - if not self.checkpoint.is_pserver: + if not self.checkpoint_cfg.is_pserver: epoch_id, step_id = io.load_trainer_args( - self.checkpoint.checkpoint_dir, self.checkpoint.load_serial, - self.trainer_id, self._get_checkpoint_load_args()) - self.checkpoint.epoch_id = int(epoch_id) - self.checkpoint.step_id = int(step_id) + self.checkpoint_cfg.checkpoint_dir, + self.checkpoint_cfg.load_serial, self.trainer_id, + self._get_checkpoint_load_args()) + self.checkpoint_cfg.epoch_id = int(epoch_id) + self.checkpoint_cfg.step_id = int(step_id) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -256,7 +256,7 @@ class Trainer(object): t.transpile( self.trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": - if self.checkpoint: + if self.checkpoint_cfg: self.is_pserver = True self.train_program = t.get_pserver_program(current_endpoint) @@ -351,10 +351,10 @@ class Trainer(object): self._train_by_any_executor(event_handler, exe, num_epochs, reader) def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): - if self.checkpoint: + if self.checkpoint_cfg: epochs = [ epoch_id for epoch_id in range(num_epochs) - if epoch_id >= self.checkpoint.epoch_id + if epoch_id >= self.checkpoint_cfg.epoch_id ] else: epochs = [epoch_id for epoch_id in range(num_epochs)] @@ -366,8 +366,8 @@ class Trainer(object): self._clean_checkpoint() return - if self.checkpoint and self.checkpoint.load_serial \ - and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: + if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \ + and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id: continue begin_event = BeginStepEvent(epoch_id, step_id) @@ -381,10 +381,12 @@ class Trainer(object): else: metrics = exe.run(feed=data, fetch_list=[]) - self._save_checkpoint(epoch_id, step_id) + if self.checkpoint_cfg: + self._save_checkpoint(epoch_id, step_id) event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndEpochEvent(epoch_id)) - self._clean_checkpoint() + if self.checkpoint_cfg: + self._clean_checkpoint() def _test_by_executor(self, reader, feed_order, fetch_list): with executor.scope_guard(self.scope): @@ -424,9 +426,8 @@ class Trainer(object): return self._get_parallel_executor() def _clean_checkpoint(self): - if not self.checkpoint: - return - io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir) + assert self.checkpoint_cfg + io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir) def _get_checkpoint_load_args(self): """ @@ -444,19 +445,18 @@ class Trainer(object): return trainer_args def _save_checkpoint(self, epoch_id, step_id): - if not self.checkpoint: - return + assert self.checkpoint_cfg - if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0: + if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) io.save_checkpoint( executor=exe, - checkpoint_dir=self.checkpoint.checkpoint_dir, + checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, trainer_id=self.trainer_id, is_chief=self.chief, trainer_args=self._get_checkpoint_save_args(epoch_id, step_id), main_program=self.train_program, - max_num_checkpoints=self.checkpoint.max_num_checkpoints) + max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints) def build_feed_var_list(program, feed_order):