diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index bd3c2e3d9a3bb9d7fdf5a87b3fb239ce5dcfd71d..ed560304e25fd9b5a5d8776051190e0ab984550a 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -456,40 +456,18 @@ def get_parameter_value_by_name(name, executor, program=None):
     return get_parameter_value(var, executor)
 
 
-def load_persist_vars_without_grad(executor, dirname, program):
-    """
-    load_persist_vars_without_grad will load variables from a directory by an executor,
-    the variable named end with "@GRAD" will not be loaded.
-    """
-    load_vars(
-        executor,
-        dirname=dirname,
-        main_program=program,
-        predicate=_is_checkpoint_var,
-        filename=None)
-
-
-def save_persist_vars_without_grad(executor, dirname, program):
-    """
-    save_persist_vars_without_grad will save variables to a directory by an executor,
-    the variable named end with "@GRAD" will not be saved.
-    """
-    save_vars(
-        executor,
-        dirname=dirname,
-        main_program=program,
-        vars=None,
-        predicate=_is_checkpoint_var,
-        filename=None)
-
-
 SUCCESS_MARK_FILENAME = "_SUCCESS"
 CHECKPOINT_PREFIX = "checkpoint"
+MODEL_DIR = "__model__"
+TRAINER_PREFIX = "trainer"
 CHECKPOINT_SEPARATOR = "_"
 
 
 def save_checkpoint(executor,
                     checkpoint_dir,
+                    trainer_id,
+                    is_chief=False,
+                    trainer_args=None,
                     main_program=None,
                     max_num_checkpoints=3):
     """
@@ -502,22 +480,35 @@ def save_checkpoint(executor,
     :param checkpoint_dir
     :param main_program
     :param max_num_checkpoints
+    :param is_chief
     """
     if checkpoint_dir is None:
         raise ValueError("The values of 'checkpoint_dir' should not be None")
 
+    if trainer_args and not isinstance(trainer_args, dict):
+        raise TypeError("The type of 'trainer_args' should be dict")
+
     if not os.path.isdir(checkpoint_dir):
         os.makedirs(checkpoint_dir)
 
     serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial)
 
-    save_persist_vars_without_grad(executor, cur_dir, main_program)
-    _write_success(cur_dir)
+    if is_chief:
+        save_persist_vars_without_grad(executor, cur_dir, main_program)
+
+    save_trainer_args(cur_dir, trainer_id, trainer_args)
     _lru_delete(checkpoint_dir, max_num_checkpoints)
 
 
-def load_checkpoint(executor, checkpoint_dir, main_program=None):
+def need_load_checkpoint(checkpoint_dir):
+    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
+    if serial < 0:
+        return None
+    return serial
+
+
+def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     """
     Load checkpoint from a directory by executor,
     it will find the most recent saved checkpoint file and load it auto.
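As a usage sketch (not part of the patch), the reworked `save_checkpoint`/`load_checkpoint` pair is expected to be driven like this; `train_program` and the checkpoint path are placeholders, and only the chief trainer persists model variables while every trainer records its own `trainer_args`:

```python
import paddle.fluid as fluid
import paddle.fluid.io as io

exe = fluid.Executor(fluid.CPUPlace())
train_program = fluid.default_main_program()  # placeholder program

# need_load_checkpoint returns None when no finished checkpoint exists,
# otherwise the serial number of the newest complete one.
serial = io.need_load_checkpoint("/tmp/ckpt")
if serial is not None:
    io.load_checkpoint(exe, "/tmp/ckpt", serial, train_program)

# Every trainer calls save_checkpoint; model variables are written only
# when is_chief=True, while trainer_args are written once per trainer.
io.save_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",
    trainer_id=0,
    is_chief=True,
    trainer_args={"epoch_id": 0, "step_id": 0},
    main_program=train_program,
    max_num_checkpoints=3)
```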
@@ -528,14 +519,17 @@ def load_checkpoint(executor, checkpoint_dir, main_program=None):
     """
 
     if checkpoint_dir is None:
-        raise ValueError("The values of 'checkpoint_dir' should not be None")
+        raise ValueError(
+            "The values of 'checkpoint_dir' or 'serial' should not be None")
 
-    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
+    if serial is None or serial < 0:
+        raise ValueError("The values of 'serial' should not be None or < 0")
 
-    if serial < 0:
-        return
+    if main_program is None:
+        raise ValueError("The values of 'main_program' should not be None")
 
     cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    cur_dir = _get_model_dir(cur_dir)
     load_persist_vars_without_grad(executor, cur_dir, main_program)
 
 
@@ -552,6 +546,68 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)
 
 
+def load_persist_vars_without_grad(executor, dirname, program, nest=True):
+    """
+    load_persist_vars_without_grad will load variables from a directory by an executor;
+    variables whose names end with "@GRAD" will not be loaded.
+    """
+
+    if nest:
+        dirname = _get_model_dir(dirname)
+
+    load_vars(
+        executor,
+        dirname=dirname,
+        main_program=program,
+        predicate=_is_checkpoint_var,
+        filename=None)
+
+
+def save_persist_vars_without_grad(executor, dirname, program):
+    """
+    save_persist_vars_without_grad will save variables to a directory by an executor;
+    variables whose names end with "@GRAD" will not be saved.
+    """
+    cur_dir = _get_model_dir(dirname)
+    save_vars(
+        executor,
+        dirname=cur_dir,
+        main_program=program,
+        vars=None,
+        predicate=_is_checkpoint_var,
+        filename=None)
+    _write_success(cur_dir)
+
+
+def save_trainer_args(dirname, trainer_id, trainer_args):
+    if not isinstance(trainer_args, dict):
+        raise TypeError("The type of 'trainer_args' should be dict")
+    cur_dir = _get_trainer_dir(dirname, trainer_id)
+
+    for name, value in trainer_args.iteritems():
+        args_file = os.path.join(cur_dir, name)
+        with open(args_file, 'w') as f:
+            f.write(str(value))
+    _write_success(cur_dir)
+
+
+def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    cur_dir = _get_trainer_dir(cur_dir, trainer_id)
+
+    if not isinstance(trainer_args, list):
+        raise TypeError("The type of 'trainer_args' should be list")
+
+    ret_values = []
+
+    for arg in trainer_args:
+        cur_file = os.path.join(cur_dir, arg)
+        with open(cur_file, 'r') as f:
+            contents = f.read()
+            ret_values.append(contents.strip())
+    return ret_values
+
+
 def _is_checkpoint_var(var):
     """
     the checkpoint will not save or load all the variables.
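For orientation, a hedged sketch of the on-disk layout these helpers produce and of the `trainer_args` round trip; the paths and values are hypothetical:

```python
# Layout for serial 0 and trainer 0 under /tmp/ckpt:
#
#   /tmp/ckpt/checkpoint_0/__model__/...        <- model vars (chief only)
#   /tmp/ckpt/checkpoint_0/__model__/_SUCCESS   <- completion mark
#   /tmp/ckpt/checkpoint_0/trainer_0/epoch_id   <- contains "5"
#   /tmp/ckpt/checkpoint_0/trainer_0/step_id    <- contains "120"
import paddle.fluid.io as io

# One plain-text file is written per trainer_args key.
io.save_trainer_args(
    "/tmp/ckpt/checkpoint_0", 0, {"epoch_id": 5, "step_id": 120})

# Values come back as stripped strings, in the requested order.
epoch_id, step_id = io.load_trainer_args(
    "/tmp/ckpt", 0, 0, ["epoch_id", "step_id"])
assert (int(epoch_id), int(step_id)) == (5, 120)
```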
@@ -583,7 +639,31 @@ def _get_dir_serial(dirname):
 
 def _get_serial_dir(dirname, serial):
     serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
-    return os.path.join(dirname, serial_folder)
+    serial_dir = os.path.join(dirname, serial_folder)
+
+    if not os.path.isdir(serial_dir):
+        os.makedirs(serial_dir)
+
+    return serial_dir
+
+
+def _get_model_dir(dirname):
+    model_dir = os.path.join(dirname, MODEL_DIR)
+
+    if not os.path.isdir(model_dir):
+        os.makedirs(model_dir)
+
+    return model_dir
+
+
+def _get_trainer_dir(dirname, trainer_id):
+    trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
+    trainer_dir = os.path.join(dirname, trainer_folder)
+
+    if not os.path.isdir(trainer_dir):
+        os.makedirs(trainer_dir)
+
+    return trainer_dir
 
 
 def _lru_delete(dirname, max_num_checkpoints=3):
@@ -638,7 +718,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
         return -1
 
     success_path = os.path.join(
-        _get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
+        _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
+        SUCCESS_MARK_FILENAME)
     if os.path.isfile(success_path):
         return serial
 
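The success-mark rule that `_get_lastest_checkpoint_dir` now enforces can be restated as a self-contained sketch (a re-implementation for illustration, not the patch's code): a serial is loadable only once the chief has written `__model__/_SUCCESS`, so half-written checkpoints are skipped:

```python
import os

CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
MODEL_DIR = "__model__"
SUCCESS_MARK_FILENAME = "_SUCCESS"


def latest_complete_serial(checkpoint_dir):
    """Return the highest serial whose __model__ dir carries the success
    mark, or -1 when nothing is loadable."""
    best = -1
    if not os.path.isdir(checkpoint_dir):
        return best
    prefix = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR
    for name in os.listdir(checkpoint_dir):
        if not name.startswith(prefix):
            continue
        try:
            serial = int(name[len(prefix):])
        except ValueError:
            continue  # not a checkpoint_<serial> directory
        mark = os.path.join(checkpoint_dir, name, MODEL_DIR,
                            SUCCESS_MARK_FILENAME)
        if serial > best and os.path.isfile(mark):
            best = serial
    return best
```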
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 3cf96ac251132744115795b3dd58ddd7a6ac4d00..206d582cdcaafd5e81c5348f7a6054214005a518 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -79,6 +79,9 @@ class CheckpointConfig(object):
         else:
             self.step_interval = step_interval
 
+        self.epoch_id = 0
+        self.step_id = 0
+
 
 def check_and_get_place(place):
     """
@@ -132,6 +135,7 @@ class Trainer(object):
 
         # config for checkpoint
         # only chief worker will save variables
+        self.trainer_id = 0
         self.chief = True
         self.checkpoint = checkpoint_config
         if self.checkpoint and \
@@ -139,6 +143,8 @@ class Trainer(object):
             raise TypeError(
                 "The checkpoint_config shoule be an instance of CheckpointConfig"
             )
+        self.load_checkpoint_serial = io.need_load_checkpoint(
+            self.checkpoint.checkpoint_dir)
 
         self.scope = core.Scope()
 
@@ -168,15 +174,25 @@ class Trainer(object):
             exe = executor.Executor(place)
             exe.run(self.startup_program)
 
-            if self.checkpoint:
+            if self.load_checkpoint_serial:
                 exe = executor.Executor(place)
                 io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                                   self.load_checkpoint_serial,
                                    self.startup_program)
 
-            if param_path:
+                epoch_id, step_id = io.load_trainer_args(
+                    self.checkpoint.checkpoint_dir, self.load_checkpoint_serial,
+                    self.trainer_id, ["epoch_id", "step_id"])
+                self.checkpoint.epoch_id = int(epoch_id)
+                self.checkpoint.step_id = int(step_id)
+
+            if param_path and os.path.isdir(param_path):
                 # load params from param_path into scope
                 io.load_persist_vars_without_grad(
-                    exe, dirname=param_path, program=self.startup_program)
+                    exe,
+                    dirname=param_path,
+                    program=self.startup_program,
+                    nest=False)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -333,11 +349,20 @@ class Trainer(object):
             self._train_by_any_executor(event_handler, exe, num_epochs, reader)
 
     def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
-        for epoch_id in range(num_epochs):
+        epochs = [
+            epoch_id for epoch_id in range(num_epochs)
+            if epoch_id >= self.checkpoint.epoch_id
+        ]
+        for epoch_id in epochs:
             event_handler(BeginEpochEvent(epoch_id))
             for step_id, data in enumerate(reader()):
                 if self.__stop:
+                    self._clean_checkpoint()
                     return
+
+                if self.checkpoint and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
+                    continue
+
                 begin_event = BeginStepEvent(epoch_id, step_id)
                 event_handler(begin_event)
                 if begin_event.fetch_metrics:
@@ -352,6 +377,7 @@ class Trainer(object):
                 event_handler(EndStepEvent(epoch_id, step_id, metrics))
                 self._save_checkpoint(epoch_id, step_id)
             event_handler(EndEpochEvent(epoch_id))
+        self._clean_checkpoint()
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
         with executor.scope_guard(self.scope):
@@ -390,17 +416,29 @@ class Trainer(object):
                 loss_name=self.train_func_outputs[0].name)
             return self._get_parallel_executor()
 
+    def _clean_checkpoint(self):
+        if not self.checkpoint:
+            return
+        io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
+
     def _save_checkpoint(self, epoch_id, step_id):
-        if not self.checkpoint or not self.chief:
+        if not self.checkpoint:
             return
 
         if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
+            trainer_args = {}
+            trainer_args["epoch_id"] = epoch_id
+            trainer_args["step_id"] = step_id
+
             exe = executor.Executor(self.place)
             io.save_checkpoint(
                 executor=exe,
                 checkpoint_dir=self.checkpoint.checkpoint_dir,
-                max_num_checkpoints=self.checkpoint.max_num_checkpoints,
-                main_program=self.train_program)
+                trainer_id=self.trainer_id,
+                is_chief=self.chief,
+                trainer_args=trainer_args,
+                main_program=self.train_program,
+                max_num_checkpoints=self.checkpoint.max_num_checkpoints)
 
 
 def build_feed_var_list(program, feed_order):
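Putting the trainer side together, a hedged end-to-end sketch; the `CheckpointConfig` keyword names are assumed from the attributes used above, and `train_fn`, `optimizer_fn`, `handler`, and `reader` are placeholders the caller must supply:

```python
import paddle.fluid as fluid
from paddle.fluid.trainer import CheckpointConfig, Trainer

# Assumed keywords: the patch shows the config carrying checkpoint_dir,
# max_num_checkpoints, epoch_interval, and step_interval, plus the new
# epoch_id/step_id resume markers that default to 0.
config = CheckpointConfig(
    checkpoint_dir="/tmp/ckpt",
    max_num_checkpoints=3,
    epoch_interval=1,
    step_interval=10)

trainer = Trainer(
    train_func=train_fn,          # placeholder: builds the loss
    optimizer_func=optimizer_fn,  # placeholder: returns an optimizer
    place=fluid.CPUPlace(),
    checkpoint_config=config)

# On construction the trainer probes io.need_load_checkpoint(); when a
# complete checkpoint exists, variables plus the saved epoch_id/step_id
# are restored, and train() skips already-finished steps before resuming.
trainer.train(num_epochs=10, event_handler=handler, reader=reader)
```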