From 5eea5db95fb6eaca2db9a0af63e871a9fc29c6bf Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 29 May 2018 14:37:59 +0800
Subject: [PATCH] optimized checkpoint and save_model

---
 python/paddle/fluid/__init__.py |  1 +
 python/paddle/fluid/io.py       | 61 +++++++++++++++------------------
 python/paddle/fluid/trainer.py  | 40 +++++++++++++++------
 3 files changed, 58 insertions(+), 44 deletions(-)

diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 859605d0053..aece8fc1490 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
 from trainer import EndEpochEvent
 from trainer import BeginStepEvent
 from trainer import EndStepEvent
+from trainer import CheckpointConfig
 import inferencer
 from inferencer import Inferencer
 
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index f626039363a..aa039bdfaa4 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -491,7 +491,6 @@ CHECKPOINT_SEPARATOR = "_"
 def save_checkpoint(executor,
                     checkpoint_dir=None,
                     max_num_checkpoints=3,
-                    save_interval_secs=600,
                     main_program=None):
     """
     Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
@@ -511,15 +510,10 @@ def save_checkpoint(executor,
     if not os.path.isdir(checkpoint_dir):
         os.makedirs(checkpoint_dir)
 
-    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
-    if serial >= 0 and not _interval_secs_exceed(
-            _get_serial_dir(serial, checkpoint_dir), save_interval_secs):
-        return
-
-    serial += 1
-    cur_dir = _get_serial_dir(serial, checkpoint_dir)
+    serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
 
-    load_persist_vars_without_grad(executor, cur_dir, main_program)
+    save_persist_vars_without_grad(executor, cur_dir, main_program)
     _write_success(cur_dir)
     _lru_delete(checkpoint_dir, max_num_checkpoints)
 
@@ -542,7 +536,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
     if serial < 0:
         return
 
-    cur_dir = _get_serial_dir(serial, checkpoint_dir)
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
 
     load_persist_vars_without_grad(executor, cur_dir, main_program)
 
@@ -559,11 +553,6 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)
 
 
-def _get_serial_dir(serial, checkpoint_dir):
-    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
-    return os.path.join(checkpoint_dir, serial_folder)
-
-
 def _is_checkpoint_var(var):
     """
     the checkpoint will not save or load all the variables.
@@ -582,29 +571,37 @@ def _is_checkpoint_var(var):
     return var.persistable
 
 
-def _interval_secs_exceed(dirname, save_interval_secs):
-    dir_time = os.path.getmtime(dirname)
-    if save_interval_secs > (time.time() - dir_time):
-        return False
-    return True
+def _get_dir_serial(dirname):
+    _, serial = dirname.split(CHECKPOINT_SEPARATOR)
+
+    serial_num = -1
+    try:
+        serial_num = int(serial)
+    except ValueError:
+        serial_num = -1
+    return serial_num
+
+
+def _get_serial_dir(dirname, serial):
+    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
+    return os.path.join(dirname, serial_folder)
 
 
 def _lru_delete(dirname, max_num_checkpoints=3):
     dirs = os.listdir(dirname)
-    serials = []
+    serial_map = {}
     for serial in dirs:
-        try:
-            serials.append(int(serial))
-        except ValueError:
-            continue
+        serial_num = _get_dir_serial(serial)
+        serial_map[serial_num] = serial
 
-    if len(serials) <= max_num_checkpoints:
+    if len(serial_map.keys()) <= max_num_checkpoints:
         return
 
+    serials = serial_map.keys()
     serials.sort(reverse=True)
     serials = serials[max_num_checkpoints:]
     for serial in serials:
-        cur_dir = os.path.join(dirname, str(serial))
+        cur_dir = _get_serial_dir(dirname, serial)
         shutil.rmtree(cur_dir)
 
 
@@ -633,20 +630,18 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
         """
         is _SUCCESS in this dir
         """
-        _, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
-        try:
-            int(serial)
-        except ValueError:
+        serial = _get_dir_serial(cur_dir)
+        if serial == -1:
             return -1
 
         if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
             return -1
 
         success_path = os.path.join(
-            _get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
+            _get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
         if os.path.isfile(success_path):
-            return int(serial)
+            return serial
 
     if not os.path.isdir(checkpoint_dir):
         return -1
 
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index b4b7b75b96e..3cf96ac2511 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -60,11 +60,24 @@ class CheckpointConfig(object):
     def __init__(self,
                  checkpoint_dir=None,
                  max_num_checkpoints=3,
-                 save_interval_secs=600):
+                 epoch_interval=1,
+                 step_interval=10):
         if checkpoint_dir is None:
             self.checkpoint_dir = os.getcwd()
+        else:
+            self.checkpoint_dir = checkpoint_dir
+
         self.max_num_checkpoints = max_num_checkpoints
-        self.save_interval_secs = save_interval_secs
+
+        if epoch_interval < 1:
+            self.epoch_interval = 1
+        else:
+            self.epoch_interval = epoch_interval
+
+        if step_interval < 1:
+            self.step_interval = 10
+        else:
+            self.step_interval = step_interval
 
 
 def check_and_get_place(place):
@@ -290,14 +303,6 @@ class Trainer(object):
         exe = executor.Executor(self.place)
         io.save_persistables(exe, dirname=param_path)
 
-    def _save_checkpoint(self):
-        if self.checkpoint and self.chief:
-            exe = executor.Executor(self.place)
-            io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
-                               self.checkpoint.max_num_checkpoints,
-                               self.checkpoint.save_interval_secs,
-                               self.train_program)
-
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
         with framework.program_guard(
@@ -343,8 +348,9 @@ class Trainer(object):
                                       ])
             else:
                 metrics = exe.run(feed=data, fetch_list=[])
+
                 event_handler(EndStepEvent(epoch_id, step_id, metrics))
-            self._save_checkpoint()
+                self._save_checkpoint(epoch_id, step_id)
             event_handler(EndEpochEvent(epoch_id))
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
@@ -384,6 +390,18 @@ class Trainer(object):
                 loss_name=self.train_func_outputs[0].name)
         return self._get_parallel_executor()
 
+    def _save_checkpoint(self, epoch_id, step_id):
+        if not self.checkpoint or not self.chief:
+            return
+
+        if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
+            exe = executor.Executor(self.place)
+            io.save_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint.checkpoint_dir,
+                max_num_checkpoints=self.checkpoint.max_num_checkpoints,
+                main_program=self.train_program)
+
 
 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
--
GitLab
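
Note: a minimal usage sketch of the configuration this patch introduces. The
CheckpointConfig class, its arguments, and the interval semantics come straight
from the diff above; the directory value and the way the config is handed to
the Trainer are illustrative assumptions, not part of the change.

    import paddle.fluid as fluid

    # CheckpointConfig is exported from paddle.fluid by this patch.
    config = fluid.CheckpointConfig(
        checkpoint_dir="/tmp/paddle_ckpt",  # falls back to os.getcwd() when None
        max_num_checkpoints=3,              # _lru_delete keeps only the newest 3 serial dirs
        epoch_interval=1,                   # the chief trainer writes a checkpoint only when
        step_interval=10)                   # epoch_id % 1 == 0 and step_id % 10 == 0

    # Assumed wiring (not shown in the diff): pass `config` to the Trainer; the
    # chief trainer then calls io.save_checkpoint from
    # Trainer._save_checkpoint(epoch_id, step_id) right after each EndStepEvent.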