diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index c4fad620f0c49bb6b0ad3be22a564c16619efb0b..68aee304a6761e97a0dab4183611d9d07152da16 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
 from trainer import EndEpochEvent
 from trainer import BeginStepEvent
 from trainer import EndStepEvent
+from trainer import CheckpointConfig
 
 import inferencer
 from inferencer import Inferencer
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 8e58e5eb794e1bb507ab05394a1f7b57a1d2ed42..6323c9899e0080b436a52f852c647466b8f94bc1 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -24,7 +24,8 @@ __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model',
     'get_inference_program', 'save_checkpoint', 'load_checkpoint',
-    'clean_checkpoint'
+    'clean_checkpoint', 'load_persist_vars_without_grad',
+    'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
 ]
 
 
@@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None):
 
 SUCCESS_MARK_FILENAME = "_SUCCESS"
 CHECKPOINT_PREFIX = "checkpoint"
+MODEL_DIR = "__model__"
+TRAINER_PREFIX = "trainer"
 CHECKPOINT_SEPARATOR = "_"
 
 
 def save_checkpoint(executor,
-                    checkpoint_dir=None,
-                    max_num_checkpoints=3,
-                    save_interval_secs=600,
-                    main_program=None):
     """
-    Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
-    the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
-    to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
-    The interval between two saved checkpoints must greater than save_interval_secs.
+                    checkpoint_dir,
+                    trainer_id,
+                    trainer_args=None,
+                    main_program=None,
+                    max_num_checkpoints=3):
+    """
+    save_checkpoint saves the persistable LoDTensor variables of main_program
+    into checkpoint_dir. Each checkpoint is written to a sub-directory named by
+    an increasing serial number starting from 0; at most max_num_checkpoints
+    serial directories are kept, and the oldest ones are deleted first.
-    :param executor
-    :param checkpoint_dir
-    :param max_num_checkpoints
-    :param save_interval_secs
-    :param main_program
+    :param executor the executor used to save the variables
+    :param checkpoint_dir the checkpoint directory
+    :param trainer_id the current trainer id; the trainer with id 0 is the chief
+    :param trainer_args a dict of runtime arguments (e.g. epoch_id, step_id) saved per trainer
+    :param main_program all persistable variables in this program will be saved
+    :param max_num_checkpoints keep at most this many checkpoint serial directories
     """
     if checkpoint_dir is None:
-        checkpoint_dir = os.getcwd()
+        raise ValueError("'checkpoint_dir' should not be None")
+
+    if trainer_args:
+        assert isinstance(trainer_args, dict)
 
     if not os.path.isdir(checkpoint_dir):
         os.makedirs(checkpoint_dir)
 
-    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
-    if serial >= 0 and not _interval_secs_exceed(
-            _get_serial_dir(serial, checkpoint_dir), save_interval_secs):
-        return
+    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
 
-    serial += 1
-    cur_dir = _get_serial_dir(serial, checkpoint_dir)
+    save_trainer_args(cur_dir, trainer_id, trainer_args)
 
-    save_vars(
-        executor,
-        dirname=cur_dir,
-        main_program=main_program,
-        vars=None,
-        predicate=_is_checkpoint_var,
-        filename=None)
-    _write_success(cur_dir)
-    _lru_delete(checkpoint_dir, max_num_checkpoints)
+    if trainer_id == 0:
+        save_persist_vars_without_grad(executor, cur_dir, main_program)
+
+    _scroll_delete(checkpoint_dir, max_num_checkpoints)
 
 
-def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
+def load_checkpoint(executor, checkpoint_dir, serial, main_program):
     """
-    Load checkpoint from a directory by executor,
-    it will find the most recent saved checkpoint file and load it auto.
+    load_checkpoint loads the persistable variables saved under the given serial
+    number in checkpoint_dir back into main_program by the executor.
 
-    :param executor
-    :param checkpoint_dir
-    :param main_program
+    :param executor the executor used to load the variables
+    :param checkpoint_dir the checkpoint directory
+    :param serial the serial number of the checkpoint folder to load
+    :param main_program all persistable variables in this program will be loaded
     """
     if checkpoint_dir is None:
-        checkpoint_dir = os.getcwd()
+        raise ValueError("'checkpoint_dir' should not be None")
 
-    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
+    if serial is None or serial < 0:
+        raise ValueError("'serial' should not be None or < 0")
 
-    if serial < 0:
-        return
+    if main_program is None:
+        raise ValueError("'main_program' should not be None")
 
-    cur_dir = _get_serial_dir(serial, checkpoint_dir)
-
-    load_vars(
-        executor,
-        dirname=cur_dir,
-        main_program=main_program,
-        predicate=_is_checkpoint_var,
-        filename=None)
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    load_persist_vars_without_grad(executor, cur_dir, main_program, True)
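
Taken together, the new save_checkpoint/load_checkpoint pair forms a simple round trip keyed by serial number. A minimal sketch of driving it directly, where the executor, program, and the /tmp/ckpt path are illustrative rather than part of the patch:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
program = fluid.default_main_program()

# trainer 0 is the chief and the only one that saves model variables;
# every trainer saves its own runtime arguments.
fluid.io.save_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",
    trainer_id=0,
    trainer_args={"epoch_id": 5, "step_id": 120},
    main_program=program,
    max_num_checkpoints=3)

# on restart, pick the newest _SUCCESS-marked serial and restore it
serial = fluid.io.get_latest_checkpoint_serial("/tmp/ckpt")
if serial >= 0:
    fluid.io.load_checkpoint(exe, "/tmp/ckpt", serial, program)
```
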
 
 
 def clean_checkpoint(checkpoint_dir, delete_dir=False):
     """
     clean_checkpoint deletes the checkpoint directories saved earlier; when
     training exits normally, the trainer calls it to clean up.
     delete_dir only works when the directory is empty, otherwise, OSError is raised.
+
+    :param checkpoint_dir the checkpoint directory
+    :param delete_dir if True, also remove the (empty) checkpoint_dir itself
     """
+
     if checkpoint_dir is None:
-        checkpoint_dir = os.getcwd()
-    _lru_delete(checkpoint_dir, max_num_checkpoints=0)
+        raise ValueError("'checkpoint_dir' should not be None")
+    _scroll_delete(checkpoint_dir, max_num_checkpoints=0)
     if delete_dir and not os.listdir(checkpoint_dir):
         os.rmdir(checkpoint_dir)
 
 
-def _get_serial_dir(serial, checkpoint_dir):
-    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
-    return os.path.join(checkpoint_dir, serial_folder)
+def load_persist_vars_without_grad(executor,
+                                   dirname,
+                                   program,
+                                   has_model_dir=False):
+    """
+    load_persist_vars_without_grad loads variables from a directory by an
+    executor; variables whose names end with "@GRAD" are not loaded.
+
+    :param executor the executor used to load the variables
+    :param dirname the checkpoint directory
+    :param program all persistable variables in this program will be loaded
+    :param has_model_dir if True, load from the sub-directory named __model__
+    """
+
+    if has_model_dir:
+        dirname = _get_model_dir(dirname)
+
+    load_vars(
+        executor,
+        dirname=dirname,
+        main_program=program,
+        predicate=_is_checkpoint_var,
+        filename=None)
+
+
+def save_persist_vars_without_grad(executor, dirname, program):
+    """
+    save_persist_vars_without_grad saves variables to a directory by an
+    executor; variables whose names end with "@GRAD" are not saved.
+
+    :param executor the executor used to save the variables
+    :param dirname the checkpoint directory
+    :param program all persistable variables in this program will be saved
+    """
+    cur_dir = _get_model_dir(dirname)
+    save_vars(
+        executor,
+        dirname=cur_dir,
+        main_program=program,
+        vars=None,
+        predicate=_is_checkpoint_var,
+        filename=None)
+    _write_success(cur_dir)
+
+
+def save_trainer_args(dirname, trainer_id, trainer_args):
+    assert isinstance(trainer_args, dict)
+
+    cur_dir = _get_trainer_dir(dirname, trainer_id)
+
+    # each runtime argument is stored as a small text file named after its key
+    for name, value in trainer_args.iteritems():
+        args_file = os.path.join(cur_dir, name)
+        with open(args_file, 'w') as f:
+            f.write(str(value))
+    _write_success(cur_dir)
+
+
+def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+    assert isinstance(trainer_args, list)
+
+    cur_dir = _get_serial_dir(checkpoint_dir, serial)
+    cur_dir = _get_trainer_dir(cur_dir, trainer_id)
+
+    ret_values = []
+
+    # read the arguments back in the order they were requested
+    for arg in trainer_args:
+        cur_file = os.path.join(cur_dir, arg)
+        with open(cur_file, 'r') as f:
+            contents = f.read()
+            ret_values.append(contents.strip())
+    return ret_values
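
save_trainer_args stores each dict entry as a text file named after its key under trainer_<id>/, and load_trainer_args returns the values, as strings, in the order the keys are requested. A short sketch under the same illustrative /tmp/ckpt path as above:

```python
import paddle.fluid as fluid

serial = fluid.io.get_latest_checkpoint_serial("/tmp/ckpt")
# reads <ckpt>/checkpoint_<serial>/trainer_0/epoch_id and .../step_id
epoch_id, step_id = fluid.io.load_trainer_args(
    "/tmp/ckpt", serial, 0, ["epoch_id", "step_id"])
print(int(epoch_id), int(step_id))  # values come back as strings
```
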
 
 
 def _is_checkpoint_var(var):
@@ -559,36 +626,74 @@ def _is_checkpoint_var(var):
         var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
         var.desc.type() == core.VarDesc.VarType.RAW:
         return False
+    # names containing "@GRAD" are gradient variables; checkpoints do not save them
+    if "@GRAD" in var.name:
+        return False
+    # names containing ".trainer_" belong to distributed training; checkpoints do not save them
+    if ".trainer_" in var.name:
+        return False
-    if var.name.endswith("@GRAD"):
+    # names containing ".block" belong to distributed training; checkpoints do not save them
+    if ".block" in var.name:
         return False
 
     return var.persistable
 
 
-def _interval_secs_exceed(dirname, save_interval_secs):
-    dir_time = os.path.getmtime(dirname)
-    if save_interval_secs > (time.time() - dir_time):
-        return False
-    return True
+def _get_dir_serial(dirname):
+    _, serial = dirname.split(CHECKPOINT_SEPARATOR)
+
+    try:
+        serial_num = int(serial)
+    except ValueError:
+        serial_num = -1
+    return serial_num
+
+
+def _get_serial_dir(dirname, serial):
+    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
+    serial_dir = os.path.join(dirname, serial_folder)
+
+    if not os.path.isdir(serial_dir):
+        os.makedirs(serial_dir)
+
+    return serial_dir
+
+
+def _get_model_dir(dirname):
+    model_dir = os.path.join(dirname, MODEL_DIR)
 
-def _lru_delete(dirname, max_num_checkpoints=3):
+    if not os.path.isdir(model_dir):
+        os.makedirs(model_dir)
+
+    return model_dir
+
+
+def _get_trainer_dir(dirname, trainer_id):
+    trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
+    trainer_dir = os.path.join(dirname, trainer_folder)
+
+    if not os.path.isdir(trainer_dir):
+        os.makedirs(trainer_dir)
+
+    return trainer_dir
+
+
+def _scroll_delete(dirname, max_num_checkpoints=3):
     dirs = os.listdir(dirname)
-    serials = []
+    serial_map = {}
     for serial in dirs:
-        try:
-            serials.append(int(serial))
-        except ValueError:
-            continue
+        serial_num = _get_dir_serial(serial)
+        serial_map[serial_num] = serial
 
-    if len(serials) <= max_num_checkpoints:
+    if len(serial_map.keys()) <= max_num_checkpoints:
         return
 
+    serials = serial_map.keys()
     serials.sort(reverse=True)
    serials = serials[max_num_checkpoints:]
     for serial in serials:
-        cur_dir = os.path.join(dirname, str(serial))
+        cur_dir = _get_serial_dir(dirname, serial)
         shutil.rmtree(cur_dir)
 
@@ -604,33 +709,30 @@ def _write_success(dirname):
         f.write(now)
 
 
-def _get_lastest_checkpoint_dir(checkpoint_dir):
+def get_latest_checkpoint_serial(checkpoint_dir):
     """
-    get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
+    get_latest_checkpoint_serial returns the largest serial number in
+    checkpoint_dir whose __model__ sub-directory contains a _SUCCESS file.
 
-    :param checkpoint_dir
+    :param checkpoint_dir the checkpoint directory
     """
-    if not checkpoint_dir.strip():
+    if not checkpoint_dir:
         return -1
 
     def has_success(checkpoint_dir, cur_dir):
         """
         whether the _SUCCESS mark file exists in this serial directory
         """
-        _, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
-
-        try:
-            int(serial)
-        except ValueError:
-            return -1
 
-        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
+        serial = _get_dir_serial(cur_dir)
+        if serial == -1 or not os.path.isdir(
+                os.path.join(checkpoint_dir, cur_dir)):
             return -1
 
         success_path = os.path.join(
-            _get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
+            _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
+            SUCCESS_MARK_FILENAME)
         if os.path.isfile(success_path):
-            return int(serial)
+            return serial
 
     if not os.path.isdir(checkpoint_dir):
         return -1
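
With these helpers, a checkpoint directory ends up with a layout roughly like the following (serial count and trainer ids are illustrative); only a serial whose __model__ sub-directory contains the _SUCCESS marker is accepted by get_latest_checkpoint_serial:

```
/tmp/ckpt/
├── checkpoint_0/
│   ├── __model__/     # persistable variables + _SUCCESS
│   └── trainer_0/     # epoch_id, step_id files + _SUCCESS
└── checkpoint_1/
    ├── __model__/
    └── trainer_0/
```
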
diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e22400a045ced16c46b0bf005155f621f249d263
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_checkpoint.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import unittest
+import os
+import tempfile
+
+
+class TestCheckpoint(unittest.TestCase):
+    def setUp(self):
+        self.dirname = tempfile.mktemp()
+        self.max_num_checkpoints = 3
+        self.epoch_interval = 1
+        self.step_interval = 1
+        self.trainer_id = 0
+        self.chief = self.trainer_id == 0
+        self.place = fluid.CPUPlace()
+        self.epoch_id = 100
+        self.step_id = 20
+
+    def test_checkpoint(self):
+        self.save_checkpoint()
+        serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
+        self.assertTrue(serial >= 0)
+        trainer_args = ["epoch_id", "step_id"]
+        epoch_id, step_id = fluid.io.load_trainer_args(
+            self.dirname, serial, self.trainer_id, trainer_args)
+        self.assertEqual(self.step_id, int(step_id))
+        self.assertEqual(self.epoch_id, int(epoch_id))
+
+        program = fluid.Program()
+        with fluid.program_guard(program):
+            exe = fluid.Executor(self.place)
+            fluid.io.load_checkpoint(exe, self.dirname, serial, program)
+
+        fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
+        self.assertFalse(os.path.isdir(self.dirname))
+
+    def save_checkpoint(self):
+        config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
+                                        self.epoch_interval, self.step_interval)
+
+        trainer_args = {}
+        trainer_args["epoch_id"] = self.epoch_id
+        trainer_args["step_id"] = self.step_id
+
+        program = fluid.Program()
+        with fluid.program_guard(program):
+            program.global_block().create_var(
+                name="scale_0",
+                persistable=True,
+                dtype="float32",
+                shape=[32, 32])
+
+        exe = fluid.Executor(self.place)
+        for i in xrange(10):
+            fluid.io.save_checkpoint(exe, config.checkpoint_dir,
+                                     self.trainer_id, trainer_args, program,
+                                     config.max_num_checkpoints)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index ac313b237eed12f39b2a8e2c7dc6397eeaa224fc..efc28d899304b01a3085891f3ae9396d57c589a1 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -27,11 +27,8 @@ import parallel_executor
 from transpiler import distribute_transpiler
 
 __all__ = [
-    'Trainer',
-    'BeginEpochEvent',
-    'EndEpochEvent',
-    'BeginStepEvent',
-    'EndStepEvent',
+    'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
+    'EndStepEvent', 'CheckpointConfig'
 ]
 
 
@@ -59,6 +56,35 @@ class EndStepEvent(object):
         self.metrics = metrics
 
 
+class CheckpointConfig(object):
+    def __init__(self,
+                 checkpoint_dir=None,
+                 max_num_checkpoints=3,
+                 epoch_interval=1,
+                 step_interval=10):
+        if checkpoint_dir is None:
+            self.checkpoint_dir = os.getcwd()
+        else:
+            self.checkpoint_dir = checkpoint_dir
+
+        self.max_num_checkpoints = max_num_checkpoints
+
+        # intervals below 1 fall back to the defaults
+        if epoch_interval < 1:
+            self.epoch_interval = 1
+        else:
+            self.epoch_interval = epoch_interval
+
+        if step_interval < 1:
+            self.step_interval = 10
+        else:
+            self.step_interval = step_interval
+
+        self.epoch_id = 0
+        self.step_id = 0
+        self.load_serial = None
+        self.is_pserver = False
+
+
 def check_and_get_place(place):
     """
     Check the type of place or get the default place
@@ -99,13 +125,24 @@ class Trainer(object):
                  optimizer_func,
                  param_path=None,
                  place=None,
-                 parallel=False):
+                 parallel=False,
+                 checkpoint_config=None):
         self.__stop = False
         self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
         # program_func. Reference: fluid.program_guard in
         # test_word2vec.py
+
+        # config for checkpoint
+        # only the chief worker (trainer_id == 0) will save model variables
+        self.trainer_id = 0
+        self.checkpoint_cfg = checkpoint_config
+        if self.checkpoint_cfg:
+            assert isinstance(self.checkpoint_cfg, CheckpointConfig)
+            serial = io.get_latest_checkpoint_serial(
+                self.checkpoint_cfg.checkpoint_dir)
+            self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
+
         self.scope = core.Scope()
 
         self.startup_program = framework.Program()
@@ -137,9 +174,25 @@ class Trainer(object):
             exe = executor.Executor(place)
             exe.run(self.startup_program)
 
-        if param_path:
+        if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
+            with self._prog_and_scope_guard():
+                exe = executor.Executor(place)
+                io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
+                                   self.checkpoint_cfg.load_serial,
+                                   self.startup_program)
+
+            if not self.checkpoint_cfg.is_pserver:
+                epoch_id, step_id = io.load_trainer_args(
+                    self.checkpoint_cfg.checkpoint_dir,
+                    self.checkpoint_cfg.load_serial, self.trainer_id,
+                    self._get_checkpoint_load_args())
+                self.checkpoint_cfg.epoch_id = int(epoch_id)
+                self.checkpoint_cfg.step_id = int(step_id)
+
+        if param_path and os.path.isdir(param_path):
             # load params from param_path into scope
-            io.load_persistables(exe, dirname=param_path)
+            io.load_persist_vars_without_grad(
+                exe, dirname=param_path, program=self.startup_program)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -194,14 +247,18 @@ class Trainer(object):
             current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
             # the unique trainer id, starting from 0, needed by trainer
             # only
-            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+
             # the role, should be either PSERVER or TRAINER
             training_role = os.getenv("PADDLE_TRAINING_ROLE")
             with self._prog_and_scope_guard():
                 t = distribute_transpiler.DistributeTranspiler()
                 t.transpile(
-                    trainer_id, pservers=pserver_endpoints, trainers=trainers)
+                    self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
                 if training_role == "PSERVER":
+                    if self.checkpoint_cfg:
+                        self.checkpoint_cfg.is_pserver = True
+
                     self.train_program = t.get_pserver_program(current_endpoint)
                     self.startup_program = t.get_startup_program(
                         current_endpoint, self.train_program)
@@ -294,11 +351,26 @@ class Trainer(object):
             self._train_by_any_executor(event_handler, exe, num_epochs, reader)
 
     def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
-        for epoch_id in range(num_epochs):
+        if self.checkpoint_cfg:
+            # when resuming, drop the epochs that have already completed
+            epochs = [
+                epoch_id for epoch_id in range(num_epochs)
+                if epoch_id >= self.checkpoint_cfg.epoch_id
+            ]
+        else:
+            epochs = [epoch_id for epoch_id in range(num_epochs)]
+
+        for epoch_id in epochs:
             event_handler(BeginEpochEvent(epoch_id))
             for step_id, data in enumerate(reader()):
                 if self.__stop:
+                    if self.checkpoint_cfg:
+                        self._clean_checkpoint()
                     return
+
+                if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
+                        and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
+                    continue
+
                 begin_event = BeginStepEvent(epoch_id, step_id)
                 event_handler(begin_event)
                 if begin_event.fetch_metrics:
@@ -309,8 +381,13 @@ class Trainer(object):
                     ])
                 else:
                     metrics = exe.run(feed=data, fetch_list=[])
+
+                if self.checkpoint_cfg:
+                    self._save_checkpoint(epoch_id, step_id)
                 event_handler(EndStepEvent(epoch_id, step_id, metrics))
             event_handler(EndEpochEvent(epoch_id))
+        if self.checkpoint_cfg:
+            self._clean_checkpoint()
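
The resume logic above skips every epoch before checkpoint_cfg.epoch_id outright and, inside the restored epoch, skips steps up to and including checkpoint_cfg.step_id. A tiny standalone sketch of that filter (plain Python; the function name is mine, not part of the patch):

```python
def steps_to_run(num_epochs, steps_per_epoch, ckpt_epoch, ckpt_step):
    # mirrors _train_by_any_executor's resume filter
    for epoch_id in range(num_epochs):
        if epoch_id < ckpt_epoch:      # whole epoch already done
            continue
        for step_id in range(steps_per_epoch):
            if epoch_id == ckpt_epoch and step_id <= ckpt_step:
                continue               # step already done in the resumed epoch
            yield epoch_id, step_id

# resuming from epoch 1, step 2, with 3 epochs of 4 steps each:
print(list(steps_to_run(3, 4, 1, 2)))
# -> [(1, 3), (2, 0), (2, 1), (2, 2), (2, 3)]
```
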
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
         with executor.scope_guard(self.scope):
@@ -349,6 +426,38 @@ class Trainer(object):
                 loss_name=self.train_func_outputs[0].name)
             return self._get_parallel_executor()
 
+    def _clean_checkpoint(self):
+        assert self.checkpoint_cfg
+        io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
+
+    def _get_checkpoint_load_args(self):
+        """
+        epoch_id and step_id are runtime arguments, not variables, so they are
+        loaded independently of the model variables.
+        """
+        return ["epoch_id", "step_id"]
+
+    def _get_checkpoint_save_args(self, epoch_id, step_id):
+        """
+        epoch_id and step_id are runtime arguments, not variables, so they are
+        saved independently of the model variables.
+        """
+        trainer_args = {}
+        trainer_args["epoch_id"] = epoch_id
+        trainer_args["step_id"] = step_id
+        return trainer_args
+
+    def _save_checkpoint(self, epoch_id, step_id):
+        assert self.checkpoint_cfg
+
+        if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
+            exe = executor.Executor(self.place)
+            io.save_checkpoint(
+                executor=exe,
+                checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
+                trainer_id=self.trainer_id,
+                trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
+                main_program=self.train_program,
+                max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
+
 
 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
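
End to end, a user opts in by passing a CheckpointConfig through the new Trainer argument; periodic saves, resume, and cleanup on normal exit then happen inside the training loop. A hedged sketch, where train_fn, optimizer_fn, train_reader, and the feed names are placeholders rather than anything defined by this patch:

```python
import paddle.fluid as fluid

config = fluid.CheckpointConfig(
    checkpoint_dir="/tmp/ckpt",
    max_num_checkpoints=3,
    epoch_interval=1,    # consider saving every epoch ...
    step_interval=10)    # ... on every 10th step

trainer = fluid.Trainer(
    train_func=train_fn,            # placeholder: builds program, returns loss
    optimizer_func=optimizer_fn,    # placeholder: returns an optimizer
    place=fluid.CPUPlace(),
    checkpoint_config=config)

# if /tmp/ckpt already holds a _SUCCESS-marked serial, training resumes
# from the saved epoch_id/step_id instead of starting over
trainer.train(
    num_epochs=5,
    event_handler=lambda event: None,
    reader=train_reader,
    feed_order=["x", "y"])
```
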