# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import contextlib import os import errno import shutil import time import core import data_feeder import executor import framework import io # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module import optimizer as opt_module import parallel_executor from transpiler import distribute_transpiler __all__ = [ 'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent', 'EndStepEvent', 'CheckpointConfig' ] class BeginEpochEvent(object): """ The begin of a training epoch. Args: epoch_id(int): The current epoch ID. """ def __init__(self, epoch_id): self.epoch = epoch_id class EndEpochEvent(object): """ The end of a training epoch. Args: epoch_id(int): The current epoch ID. """ def __init__(self, epoch_id): self.epoch = epoch_id class BeginStepEvent(object): """ The begin of a training epoch. Args: epoch_id(int): The current epoch ID. step_id(int): The current step ID. """ def __init__(self, epoch_id, step_id): self.epoch = epoch_id self.step = step_id self.fetch_metrics = True """ If fetch_metrics is true, the metrics will be fetched at the EndStepEvent. Default is True. """ class EndStepEvent(object): """ The end of a training step. Args: epoch_id(int): The current epoch ID. step_id(int): The current step ID. metrics(list): A list of fetched tensor. The order of this list is same as the :code:`train_func` returns. """ def __init__(self, epoch_id, step_id, metrics): self.epoch = epoch_id self.step = step_id self.metrics = metrics class CheckpointConfig(object): """ Parameter object for :code:`save_checkpoint` and :code:`fluid.Trainer`. Used to configuration how to save checkpoint. Args: checkpoint_dir(str): Directory path to save check point. Default is the current directory. max_num_checkpoints(int): The max number of local check points. epoch_interval(int): Every number of epoch to save check point. step_interval(int): Every number of step to save check point. Examples: >>> config = fluid.CheckpointConfig("./checkpoints") >>> trainer = fluid.Trainer(train_func=train_program, >>> place=place, >>> optimizer_func=optimizer_func, >>> checkpoint_config=config) >>> trainer.train(...) """ def __init__(self, checkpoint_dir=None, max_num_checkpoints=3, epoch_interval=1, step_interval=10): assert epoch_interval >= 1 assert step_interval >= 1 self.checkpoint_dir = checkpoint_dir \ if checkpoint_dir is not None else os.getcwd() self.max_num_checkpoints = max_num_checkpoints self.epoch_interval = epoch_interval self.step_interval = step_interval self.epoch_id = 0 self.step_id = 0 self.load_serial = None def check_and_get_place(place): """ Check the type of place or get the default place Args: place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on. Raises: TypeError if the type mismatched. Returns: the original place if it is not None. if fluid is compiled with CUDA, returns CUDAPlace(0) by default. Otherwise returns CPUPlace by default. """ if place is None: if core.is_compiled_with_cuda(): return core.CUDAPlace(0) else: return core.CPUPlace() else: if not isinstance(place, core.CUDAPlace) and not isinstance( place, core.CPUPlace): raise TypeError("Place should be either CUDAPlace or CPUPlace") return place class Trainer(object): """ A trainer wraps MultiGPU/MultiNode training loops and can be used to train a simple neural network easily. This API takes a :code:`train_func`. A :code:`train_func` is a function that return loss as it first return value. The reset value can be fetched by EndStepEvent.metrics This API also takes a :code:`optimizer_func` that will return an optimizer instance. For example, to train a MLP for MNIST dataset, the sample program is >>> import paddle.fluid as fluid >>> >>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10): >>> hidden = image >>> for layer_size in layer_sizes: >>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation) >>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax") >>> >>> def train_mnist_mlp(): >>> img = fluid.layers.data(name='image', shape=[784]) >>> label = fluid.layers.data(name='label', shape=[1], dtype='int64') >>> prediction = mlp(img) >>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label)) >>> >>> def optimizer(): >>> return fluid.optimizer.Adam() >>> >>> trainer = Trainer(train_func=train_mnist_mlp, >>> optimizer_func=optimizer, >>> place=fluid.CUDAPlace(0), >>> parallel=True) >>> >>> def train_callback(event): >>> if isinstance(event, fluid.EndStepEvent): >>> print "Epoch ID", event.epoch, "Step ID",\ >>> event.step, "AvgLoss", event.metrics[0] >>> elif isinstance(event, fluid.EndEpochEvent): >>> trainer.save_params("./model_{0}".format(event.epoch)) >>> >>> trainer.train(num_epochs=100, event_handler=train_callback) For more example, please see :ref:`api_guide_high_level_api`. Args: train_func(callable): A function which will return loss. The loss must be a scalar tensor. optimizer_func(callable): A function that returns an Optimizer object. place(CUDAPlace|CPUPlace): The device place of this trainer. If :code:`parallel=True,` all CUDA Places will be used if :code:`place` is a :code:`CUDAPlace`. parallel(bool): True if use multiple devices. checkpoint_config(CheckpointConfig): Configuration about how to save checkpoints. """ def __init__(self, train_func, optimizer_func, param_path=None, place=None, parallel=False, checkpoint_config=None): self.__stop = False self.parallel = parallel # config for checkpoint # only chief worker will save variables self.trainer_id = 0 self.checkpoint_cfg = checkpoint_config if self.checkpoint_cfg: assert isinstance(self.checkpoint_cfg, CheckpointConfig) serial = _get_latest_checkpoint_serial( self.checkpoint_cfg.checkpoint_dir) self.checkpoint_cfg.load_serial = serial if serial >= 0 else None self.scope = core.Scope() # 1. we need to generate a framework.Program by calling # program_func. Reference: fluid.program_guard in # test_word2vec.py self.startup_program = framework.Program() self.train_program = framework.Program() with framework.program_guard(self.train_program, self.startup_program): program_func_outs = train_func() self.train_func_outputs = program_func_outs if isinstance( program_func_outs, list) else [program_func_outs] self.test_program = self.train_program.clone(for_test=True) # The first element of program_func_outs is loss. loss = self.train_func_outputs[0] optimizer = optimizer_func() if not isinstance(optimizer, opt_module.Optimizer): raise TypeError( "The optimizer should be an instance of Optimizer") optimize_ops, params_grads = optimizer.minimize(loss) self.place = check_and_get_place(place) self._dist_transpile_if_necessary(optimize_ops, params_grads) # 2. move the default_main_program to self.program and run the # default_startup program on an empty core.Scope() # Run startup program with self._prog_and_scope_guard(): exe = executor.Executor(place) exe.run(self.startup_program) if self.checkpoint_cfg and self.checkpoint_cfg.load_serial is not None: self._load_checkpoint() if param_path and os.path.isdir(param_path): # load params from param_path into scope io.load_persistables( executor=exe, dirname=param_path, main_program=self.startup_program) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS if "PADDLE_TRAINER_IPS" not in os.environ: self.nccl_id_var = None else: self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) port = os.getenv("PADDLE_PSERVER_PORT") worker_ips = os.getenv("PADDLE_TRAINER_IPS") worker_endpoints = [] for ip in worker_ips.split(","): worker_endpoints.append(':'.join([ip, port])) self.num_trainers = len(worker_endpoints) current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port worker_endpoints.remove(current_endpoint) # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id # in ParallelExecutor to start # distributed training using NCCL2 self.nccl_id_var = self.startup_program.global_block().create_var( name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) self.startup_program.global_block().append_op( type="gen_nccl_id", inputs={}, outputs={"NCCLID": self.nccl_id_var}, attrs={ "endpoint": current_endpoint, "endpoint_list": worker_endpoints, "trainer_id": self.trainer_id }) def _dist_transpile_if_necessary(self, optimize_ops, params_grads): self._transpile_nccl2_dist() if self.nccl_id_var != None: return if "PADDLE_TRAINING_ROLE" not in os.environ: return # the port of all pservers, needed by both trainer and pserver port = os.getenv("PADDLE_PSERVER_PORT", "6174") # comma separated ips of all pservers, needed by trainer and # pserver pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "") eplist = [] for ip in pserver_ips.split(","): eplist.append(':'.join([ip, port])) pserver_endpoints = ",".join(eplist) # total number of workers/trainers in the job, needed by # trainer and pserver trainers = int(os.getenv("PADDLE_TRAINERS")) # the IP of the local machine, needed by pserver only current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port # the unique trainer id, starting from 0, needed by trainer # only self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) # the role, should be either PSERVER or TRAINER training_role = os.getenv("PADDLE_TRAINING_ROLE") with self._prog_and_scope_guard(): t = distribute_transpiler.DistributeTranspiler() t.transpile( self.trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": self.pserver_id = eplist.index(current_endpoint) self.pserver_endpoints = pserver_endpoints self.lookup_table_name = t.table_name if t.has_distributed_lookup_table else None self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint, self.train_program) elif training_role == "TRAINER": self.train_program = t.get_trainer_program() else: raise ValueError( 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' ) def stop(self): """ stop training """ self.__stop = True def train(self, num_epochs, event_handler, reader=None, feed_order=None): """ Start the train loop to train the model. Args: num_epochs(int): The number of epoch. An epoch will process all data in reader event_handler(callable): The event handler. A function with type (ev:Event)->void reader(callable): A reader creator object. See also :ref:`api_guide_python_reader` . feed_order(list): Feeding order of reader. None will following the defining order in program Returns: None """ training_role = os.getenv("PADDLE_TRAINING_ROLE", "") if training_role == "PSERVER": with self._prog_and_scope_guard(): exe = executor.Executor(self.place) exe.run() return if self.parallel: self._train_by_parallel_executor(num_epochs, event_handler, reader, feed_order) else: self._train_by_executor(num_epochs, event_handler, reader, feed_order) def test(self, reader, feed_order): """ Test the model on given test data Args: reader(callable): The reader that yields test data. feed_order(list): Feeding order of reader. None will following the defining order in program """ return self._test_by_executor(reader, feed_order, self.train_func_outputs) def save_params(self, param_path): """ Save all parameters into :code:`param_path`. Only No.0 trainer will save dense params. In standalone PaddlePaddle, the only existing trainer will save dense params. In distributed PaddlePaddle, the No.0 trainer will save dense params, If there have lookup table need to save, No.0 trainer will broadcast notification to all Parameter Servers to save it on Parameter Servers independent. Args: param_path(str): The path to save parameters. Returns: None """ if self.trainer_id != 0: return with self._prog_and_scope_guard(): # save params on trainer exe = executor.Executor(self.place) io.save_persistables(exe, dirname=param_path) # save params on pserver if self.lookup_table_name: _save_pserver_vars_by_notify(exe, param_path, self.lookup_table_name, self.pserver_endpoints) @contextlib.contextmanager def _prog_and_scope_guard(self): with framework.program_guard( main_program=self.train_program, startup_program=self.startup_program): with executor.scope_guard(self.scope): yield def _train_by_executor(self, num_epochs, event_handler, reader, feed_order): """ Train by Executor and single device. Args: num_epochs: event_handler: reader: feed_order: Returns: """ with self._prog_and_scope_guard(): feed_var_list = build_feed_var_list(self.train_program, feed_order) feeder = data_feeder.DataFeeder( feed_list=feed_var_list, place=self.place) exe = executor.Executor(self.place) reader = feeder.decorate_reader(reader, multi_devices=False) self._train_by_any_executor(event_handler, exe, num_epochs, reader) def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): if self.checkpoint_cfg: epochs = [ epoch_id for epoch_id in range(num_epochs) if epoch_id >= self.checkpoint_cfg.epoch_id ] else: epochs = [epoch_id for epoch_id in range(num_epochs)] for epoch_id in epochs: event_handler(BeginEpochEvent(epoch_id)) for step_id, data in enumerate(reader()): if self.__stop: if self.checkpoint_cfg: self._clean_checkpoint() return if self.checkpoint_cfg and \ self.checkpoint_cfg.load_serial is not None and \ self.checkpoint_cfg.step_id >= step_id and \ self.checkpoint_cfg.epoch_id == epoch_id: continue begin_event = BeginStepEvent(epoch_id, step_id) event_handler(begin_event) if begin_event.fetch_metrics: metrics = exe.run(feed=data, fetch_list=[ var.name for var in self.train_func_outputs ]) else: metrics = exe.run(feed=data, fetch_list=[]) if self.checkpoint_cfg: self._save_checkpoint(epoch_id, step_id) event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndEpochEvent(epoch_id)) if self.checkpoint_cfg: self._clean_checkpoint() def _test_by_executor(self, reader, feed_order, fetch_list): with executor.scope_guard(self.scope): feed_var_list = build_feed_var_list(self.test_program, feed_order) feeder = data_feeder.DataFeeder( feed_list=feed_var_list, place=self.place) exe = executor.Executor(self.place) accumulated = len(fetch_list) * [0] count = 0 for data in reader(): outs = exe.run(program=self.test_program, feed=feeder.feed(data), fetch_list=fetch_list) accumulated = [x[0] + x[1][0] for x in zip(accumulated, outs)] count += 1 return [x / count for x in accumulated] def _train_by_parallel_executor(self, num_epochs, event_handler, reader, feed_order): with self._prog_and_scope_guard(): pe = self._get_or_create_parallel_executor() feed_var_list = build_feed_var_list(self.train_program, feed_order) feeder = data_feeder.DataFeeder( feed_list=feed_var_list, place=self.place) reader = feeder.decorate_reader(reader, multi_devices=True) self._train_by_any_executor(event_handler, pe, num_epochs, reader) def _get_parallel_executor(self): return getattr(self, 'parallel_executor', None) def _get_or_create_parallel_executor(self): if self._get_parallel_executor() is None: self.parallel_executor = parallel_executor.ParallelExecutor( use_cuda=isinstance(self.place, core.CUDAPlace), loss_name=self.train_func_outputs[0].name) return self._get_parallel_executor() def _clean_checkpoint(self): assert self.checkpoint_cfg clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir) def _get_checkpoint_load_args(self): """ epoch_id and step_id are runtime arguments, they are not variables, will load them independently. """ return ["epoch_id", "step_id"] def _get_checkpoint_save_args(self, epoch_id, step_id): """ epoch_id and step_id are runtime arguments, they are not variables, will save them independently. """ trainer_args = {} trainer_args["epoch_id"] = epoch_id trainer_args["step_id"] = step_id return trainer_args def _save_checkpoint(self, epoch_id, step_id): assert self.checkpoint_cfg if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) save_checkpoint( executor=exe, checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, main_program=self.train_program, trainer_id=self.trainer_id, save_trainer_args=self._get_checkpoint_save_args(epoch_id, step_id), save_lookup_table=self.lookup_table_name, pserver_endpoints=self.pserver_endpoints, max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints) def _load_checkpoint(self): with self._prog_and_scope_guard(): exe = executor.Executor(self.place) checkpoint_dir = _get_serial_dir(self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.load_serial) # Trainer Load if self.pserver_id is None: # load model load_checkpoint( executor=exe, checkpoint_dir=checkpoint_dir, main_program=self.startup_program, role_id=self.trainer_id, is_trainer=True, load_models=True) # load trainer_args trainer_args = self._get_checkpoint_load_args() trainer_args_ret = load_checkpoint( executor=exe, checkpoint_dir=checkpoint_dir, main_program=self.startup_program, role_id=self.trainer_id, is_trainer=True, load_trainer_args=trainer_args) if len(trainer_args_ret) != 2: raise ValueError( "the return trainer_args length do not equal _get_checkpoint_load_args" ) self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0]) self.checkpoint_cfg.step_id = int(trainer_args_ret[1]) # Pserver Load else: # load model load_checkpoint( executor=exe, checkpoint_dir=checkpoint_dir, main_program=self.startup_program, role_id=self.pserver_id, is_trainer=False, load_models=True, load_lookup_table=self.lookup_table_name) # load lookup table if self.lookup_table_name: load_checkpoint( executor=exe, checkpoint_dir=checkpoint_dir, main_program=self.startup_program, role_id=self.pserver_id, is_trainer=False, load_lookup_table=self.lookup_table_name) def build_feed_var_list(program, feed_order): if not isinstance(program, framework.Program): raise TypeError("The 'program' should be an object of Program") if isinstance(feed_order, list): feed_var_list = [ program.global_block().var(var_name) for var_name in feed_order ] else: if not isinstance(feed_order, dict): raise TypeError( "The 'feed_order' should be either None, list or dict.") if not sorted(feed_order.values()) == range(len(feed_order)): raise ValueError( "The values of 'feed_order' should be a permutation of [0, len(feed_order))" ) sorted_pair_list = sorted(feed_order.items(), key=lambda item: item[1]) feed_var_list = [ program.global_block().var(pair[0]) for pair in sorted_pair_list ] return feed_var_list # move Checkpoint APIs from io.py to trainer.py, make all of them are private. SUCCESS_MARK_FILENAME = "_SUCCESS" CHECKPOINT_PREFIX = "checkpoint" MODEL_DIR = "__model__" LOOKUP_TABLE_DIR = "__lookup_table__" TRAINER_PREFIX = "trainer" CHECKPOINT_SEPARATOR = "_" def save_checkpoint(executor, checkpoint_dir, main_program=None, trainer_id=0, save_trainer_args=None, save_lookup_table=None, pserver_endpoints=None, max_num_checkpoints=3): """ This function filters out all checkpoint variables from the give main_program and then saves these variables to the `checkpoint_dir` directory. In the training precess, we generally save a checkpoint in each iteration. So there might be a lot of checkpoints in the `checkpoint_dir`. To avoid them taking too much disk space, the `max_num_checkpoints` are introduced to limit the total number of checkpoints. If the number of existing checkpints is greater than the `max_num_checkpoints`, oldest ones will be scroll deleted. A variable is a checkpoint variable and will be saved if it meets all following conditions: 1. It's persistable. 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". Args: executor(Executor): The executor to run for save checkpoint. checkpoint_dir(str): The folder where to save checkpoints. trainer_id(int): currect trainer id, if id is equal to 0, the trainer is chief. trainer_args(dict|None): Current training arguments. Such as 'epoch_id' and 'step_id'. Defaut: None main_program(Program): The program whose checkpoint variables will be saved. max_num_checkpoints(int): The max number of total number of existing checkpoints. Default: 3 save_lookup_table(string|None): the lookup table name, when use distribute lookup table, we can get lookup table name by DistributeTranspiler. table_name pserver_endpoints(list|None): the parameter server ip:port list. when use distribute lookup table, we can get pserver_endpoints by distribute arguments. Returns: None Raises: ValueError: If `checkpoint_dir` is None. AssertionError: If `trainer_args` is not a dict. Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) path = "./checkpoints" prog = fluid.default_main_program() trainer_args = {"epoch_id": 200, "step_id": 20} # just an example table_name = "share_w" ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] save_checkpoint(executor=exe, checkpoint_dir=path, trainer_id=0, trainer_args=trainer_args, main_program=prog, max_num_checkpoints=3, save_lookup_table=table_name, pserver_endpoints = ps_endpoints) """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") _make_chekcpoint_dirs(checkpoint_dir) serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1 cur_dir = _get_serial_dir(checkpoint_dir, serial, True) is_chief = trainer_id == 0 if save_trainer_args is not None: _save_trainer_args(cur_dir, trainer_id, save_trainer_args) if is_chief: if main_program is None: raise ValueError('main_program should not be None.') _save_persistable_vars(executor, cur_dir, main_program) if is_chief and save_lookup_table and pserver_endpoints: _save_pserver_vars_by_notify(executor, cur_dir, save_lookup_table, pserver_endpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints) def load_checkpoint(executor, checkpoint_dir, main_program=None, role_id=0, is_trainer=True, load_models=False, load_trainer_args=None, load_lookup_table=None): """ This function filters out all checkpoint variables from the give main_program and then try to load these variables from the `checkpoint_dir` directory. In the training precess, we generally save a checkpoint in each iteration. So there are more than one checkpoint in the `checkpoint_dir` (each checkpoint has its own sub folder), use `serial` to specify which serial of checkpoint you would like to load. A variable is a checkpoint variable and will be loaded if it meets all following conditions: 1. It's persistable. 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". Args: executor(Executor): The executor to run for loading checkpoint. checkpoint_dir(str): The folder where all checkpoints are. serial(int): The serial of checkpoint you would like to load. main_program(Program|None): The program whose checkpoint variables will be loaded. role_id(int): the trainer id or the parameter server id. is_trainer(bool): trainer is True and parameter server is False. load_trainer_args(list|None): list about load trainer args. load_lookup_table(str|None): the lookup table name Returns: None Raises: ValueError: If `checkpoint_dir` is None. ValueError: If `main_program` is None. Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) path = "./checkpoints" prog = fluid.default_main_program() load_checkpoint(executor=exe, checkpoint_dir=path, serial=9, main_program=prog) # In this example, `load_checkpoint` function # will first filters out all checkpoint variables in the default # main program, and then try to load these variables form the # folder "./checkpoints/checkpoint_9/__model__". """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") # trainer load if is_trainer: if load_models: _load_persistable_vars(executor, checkpoint_dir, main_program, True) if load_trainer_args: trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id, load_trainer_args) return trainer_args_ret # pserver load else: if load_models: if load_lookup_table: _load_persistable_vars(executor, checkpoint_dir, main_program, True, [load_lookup_table]) else: _load_persistable_vars(executor, checkpoint_dir, main_program, True) if load_lookup_table: _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id, load_lookup_table) def clean_checkpoint(checkpoint_dir, delete_dir=False): """ clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. delete_dir only works when the directory is empty, otherwise, OSError is raised. : param checkpoint_dir : param delete_dir """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") _scroll_delete(checkpoint_dir, max_num_checkpoints=0) if delete_dir and not os.listdir(checkpoint_dir): os.rmdir(checkpoint_dir) def _load_persistable_vars(executor, dirname, program, has_model_dir=False, except_vars=None): """ This function filters out all checkpoint variables from the give program and then trys to load these variables from the given directory. A variable is a checkpoint variable if it meets all following conditions: 1. It's persistable. 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". Args: executor(Executor): The executor to run for loading variables. dirname(str): The directory path. program(Program): The program whose checkpoint variables will be loaded. has_model_dir(bool): if True, the function loads variables from a sub directory named '__model__'. Default: False Returns: None Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" prog = fluid.default_main_program() _load_persistable_vars(executor=exe, dirname=param_path, program=prog, has_model_dir=True) # In this example, `_load_persistable_vars` function # will first filters out all checkpoint variables in the default # main program, and then trys to load these variables form the # folder "./my_paddle_model/__model__". """ if has_model_dir: dirname = _get_model_dir(dirname) io.load_vars( executor, dirname=dirname, main_program=program, predicate=_is_checkpoint_var(except_vars), filename=None) def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): """ The parameter server will load lookup table's local file in selectedrows variable. Args: executor(Executor): The executor to run for loading persistable variables dirname(str): The directory path main_program(Program): Find the variable named table_name in main_program pserver_id(int): the serial number in pserver_endpoints list table_name(str): lookup table name Returns: None Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) dirname = "./checkpoints/checkpoint_9/" prog = fluid.default_main_program() pserver_id = 1 table_name = "share_w" _load_lookup_table_vars(executor=exe, dirname=dirname, program=prog, pserver_id=pserver_id, table_name=table_name) """ for var in program.list_vars(): if var.name == table_name: lookup_table_var = var break assert lookup_table_var is not None lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) load_prog = framework.Program() load_block = load_prog.global_block() load_block.append_op( type='load', inputs={}, outputs={'Out': [lookup_table_var]}, attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) executor.run(load_prog) def _save_persistable_vars(executor, dirname, program): """ This function filters out all checkpoint variables from the give program and then save these variables to a sub-folder '__model__' of the given directory. A variable is a checkpoint variable if it meets all following conditions: 1. It's persistable. 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". Args: executor(Executor): The executor to run for saving variables. dirname(str): The directory path. program(Program): The program whose checkpoint variables will be saved. Returns: None Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" prog = fluid.default_main_program() _save_persistable_vars(executor=exe, dirname=param_path, program=prog) # In this example, `_save_persistable_vars` function # will first filters out all checkpoint variables in the default # main program, and then saves these variables to the folder # "./my_paddle_model/__model__". """ cur_dir = _get_model_dir(dirname) io.save_vars( executor, dirname=cur_dir, main_program=program, vars=None, predicate=_is_checkpoint_var(), filename=None) _write_success(cur_dir) def _save_pserver_vars_by_notify(executor, dirname, lookup_table, pserver_endpoints): """ This function will send checkpoint notify message from Trainer 0 to all the pservers. The checkpoint notify message contains lookup table name, the absolute path on pserver to save lookup_table. Args: executor(Executor): The executor to run for send checkpoint notify. dirname(str): The folder where to save checkpoints. lookup_table(string): the lookup table name, when use distribute lookup table, we can get lookup table name by DistributeTranspiler. table_name pserver_endpoints(list): the parameter server ip:port list. when use distribute lookup table, we can get pserver_endpoints by distribute arguments. Return: None Examples: .. code-block:: python exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" prog = fluid.default_main_program() table_name = "share_w" ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] _save_pserver_vars_by_notify(executor=exe, dirname=param_path, lookup_table=table_name, ps_endpoint_list=ps_endpoints) """ cur_dir = _get_lookuptable_dir(dirname) checkpoint_notify_program = framework.Program() checkpoint_notify_block = checkpoint_notify_program.global_block() attrs = {} attrs['epmap'] = pserver_endpoints.split(",") attrs['dir'] = cur_dir attrs['lookup_table'] = lookup_table checkpoint_notify_block.append_op( type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) executor.run(checkpoint_notify_program) def _save_trainer_args(dirname, trainer_id, trainer_args): assert isinstance(trainer_args, dict) cur_dir = _get_trainer_dir(dirname, trainer_id) for name, value in trainer_args.iteritems(): args_file = os.path.join(cur_dir, name) with open(args_file, 'w') as f: f.write(str(value)) _write_success(cur_dir) def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args): """ trainer will load some args from it's independent directory, such as epoch_id and step_id. Args: checkpoint_dir(str): The folder where all checkpoints are. serial(int): The serial of checkpoint you would like to load. trainer_id(int): current trainer id. trainer_args(list): list about load trainer args Return: None Examples: .. code-block:: python param_path = "./checkpoint/" serial = 7 trainer_id = 2 trainer_args = ["epoch_id", "step_id"] _load_trainer_args(checkpoint_dir=param_path, serial=serial, trainer_id=trainer_id, trainer_args=trainer_args) """ assert isinstance(trainer_args, list) cur_dir = _get_trainer_dir(checkpoint_dir, trainer_id) ret_values = [] for arg in trainer_args: cur_file = os.path.join(cur_dir, arg) with open(cur_file, 'r') as f: contents = f.read() ret_values.append(contents.strip()) return ret_values def _is_checkpoint_var(except_vars=None): except_vars = [] if except_vars is None else except_vars def _except_vars(var): """ the checkpoint will not save or load all the variables. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. : param var(Variable) """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.RAW: return False # @GRAD are named for gradient variables, checkpoint will not save it. if "@GRAD" in var.name: return False # .trainer_ are named for distribute train variables, checkpoint will not save it. if ".trainer_" in var.name: return False # .block is named for distribute train variables, checkpoint will not save it. if ".block" in var.name: return False if var in except_vars: return False return var.persistable return _except_vars def _make_chekcpoint_dirs(dirs): """ _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it. """ assert dirs is not None if os.path.isfile(dirs): raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) if not os.path.isdir(dirs): try: os.makedirs(dirs) except OSError as err: if err.errno != errno.EEXIST: raise err def _get_dir_serial(dirname): try: _, serial = dirname.split(CHECKPOINT_SEPARATOR) serial_num = int(serial) except ValueError: serial_num = -1 return serial_num def _get_serial_dir(dirname, serial, makedirs=False): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_dir = os.path.join(dirname, serial_folder) if makedirs: _make_chekcpoint_dirs(serial_dir) return serial_dir def _get_model_dir(dirname): model_dir = os.path.join(dirname, MODEL_DIR) _make_chekcpoint_dirs(model_dir) return model_dir def _get_lookuptable_dir(dirname): lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) _make_chekcpoint_dirs(lookuptable_dir) return lookuptable_dir def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) _make_chekcpoint_dirs(trainer_dir) return trainer_dir def _scroll_delete(dirname, max_num_checkpoints=3): dirs = os.listdir(dirname) serial_map = {} for serial in dirs: serial_num = _get_dir_serial(serial) serial_map[serial_num] = serial if len(serial_map.keys()) <= max_num_checkpoints: return serials = serial_map.keys() serials.sort(reverse=True) serials = serials[max_num_checkpoints:] for serial in serials: cur_dir = _get_serial_dir(dirname, serial) try: shutil.rmtree(cur_dir) except OSError as err: if err.errno != errno.ENOENT: raise err def _write_success(dirname): """ write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct. : param dirname """ success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) with open(success_file, 'a') as f: now = time.ctime() f.write(now) def _get_latest_checkpoint_serial(checkpoint_dir): """ get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory : param checkpoint_dir """ def has_success(checkpoint_dir, cur_dir): """ is _SUCCESS in this dir """ serial = _get_dir_serial(cur_dir) if serial == -1 or \ not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): return -1 success_path = os.path.join( _get_serial_dir(checkpoint_dir, serial), MODEL_DIR, SUCCESS_MARK_FILENAME) if os.path.isfile(success_path): return serial current_dir = -1 if not checkpoint_dir or not os.path.isdir(checkpoint_dir): return current_dir dirs = os.listdir(checkpoint_dir) for cur_dir in dirs: success_num = has_success(checkpoint_dir, cur_dir) if success_num > current_dir: current_dir = success_num return current_dir