diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 7da123dd92ed9d111d68cd70efb8ce1493452609..01c40bb90e464f8c39f7310044969fc153792f27 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -27,11 +27,8 @@ import parallel_executor
 from transpiler import distribute_transpiler
 
 __all__ = [
-    'Trainer',
-    'BeginEpochEvent',
-    'EndEpochEvent',
-    'BeginStepEvent',
-    'EndStepEvent',
+    'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
+    'EndStepEvent', 'CheckpointConfig'
 ]
 
 
@@ -59,6 +56,17 @@ class EndStepEvent(object):
         self.metrics = metrics
 
 
+class CheckpointConfig(object):
+    def __init__(self,
+                 checkpoint_dir=None,
+                 max_num_checkpoints=3,
+                 save_interval_secs=600):
+        # fall back to the current working directory when no dir is given
+        self.checkpoint_dir = checkpoint_dir or os.getcwd()
+        self.max_num_checkpoints = max_num_checkpoints
+        self.save_interval_secs = save_interval_secs
+
+
 def check_and_get_place(place):
     """
     Check the type of place or get the default place
@@ -97,9 +105,9 @@ class Trainer(object):
     def __init__(self,
                  train_func,
                  optimizer,
-                 param_path=None,
                  place=None,
-                 parallel=False):
+                 parallel=False,
+                 checkpoint_config=None):
         self.__stop = False
         self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
@@ -108,6 +116,16 @@ class Trainer(object):
         if not isinstance(optimizer, opt_module.Optimizer):
             raise TypeError("The optimizer should be an instance of Optimizer")
 
+        # config for checkpoint
+        # only chief worker will save variables
+        self.chief = True
+        self.checkpoint = checkpoint_config
+        if self.checkpoint and not isinstance(self.checkpoint,
+                                              CheckpointConfig):
+            raise TypeError(
+                "The checkpoint_config should be an instance of CheckpointConfig"
+            )
+
         self.scope = core.Scope()
 
         self.startup_program = framework.Program()
@@ -136,9 +154,10 @@ class Trainer(object):
             exe = executor.Executor(place)
             exe.run(self.startup_program)
 
-            if param_path:
-                # load params from param_path into scope
-                io.load_persistables(exe, dirname=param_path)
+            if self.checkpoint:
+                exe = executor.Executor(place)
+                io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                                   self.startup_program)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -146,6 +165,7 @@
             self.nccl_id_var = None
         else:
             self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            self.chief = self.trainer_id == 0
             port = os.getenv("PADDLE_PSERVER_PORT")
             worker_ips = os.getenv("PADDLE_TRAINER_IPS")
             worker_endpoints = []
@@ -194,6 +214,7 @@
         # the unique trainer id, starting from 0, needed by trainer
         # only
         trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        self.chief = trainer_id == 0
         # the role, should be either PSERVER or TRAINER
         training_role = os.getenv("PADDLE_TRAINING_ROLE")
         with self._prog_and_scope_guard():
@@ -263,6 +284,14 @@ class Trainer(object):
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
 
+    def _save_checkpoint(self):
+        if self.checkpoint and self.chief:
+            exe = executor.Executor(self.place)
+            io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                               self.checkpoint.max_num_checkpoints,
+                               self.checkpoint.save_interval_secs,
+                               self.train_program)
+
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
         with framework.program_guard(
@@ -309,6 +338,7 @@ class Trainer(object):
                 else:
                     metrics = exe.run(feed=data, fetch_list=[])
                 event_handler(EndStepEvent(epoch_id, step_id, metrics))
+            self._save_checkpoint()
             event_handler(EndEpochEvent(epoch_id))
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
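
Usage sketch (illustrative, not part of the patch): assuming CheckpointConfig and Trainer are re-exported by paddle.fluid and that the high-level train() API of this version takes (num_epochs, event_handler, reader, feed_order), wiring the new checkpoint_config argument into a caller's script might look like the snippet below; train_program, train_reader, feed_order, the directory path, learning rate, and epoch count are placeholders defined elsewhere by the user.

import paddle.fluid as fluid

# Hypothetical caller-side code; train_program, train_reader and feed_order
# are assumed to be defined elsewhere in the user's script.
config = fluid.CheckpointConfig(
    checkpoint_dir="/tmp/paddle_ckpt",  # placeholder directory
    max_num_checkpoints=3,              # keep at most three snapshots
    save_interval_secs=600)             # save at most every ten minutes

trainer = fluid.Trainer(
    train_func=train_program,
    optimizer=fluid.optimizer.SGD(learning_rate=0.001),
    place=fluid.CPUPlace(),
    checkpoint_config=config)           # replaces the old param_path argument

# With the patch, the chief worker calls _save_checkpoint() at the end of each
# epoch, and on restart the Trainer constructor reloads the latest checkpoint
# from checkpoint_dir via io.load_checkpoint().
trainer.train(
    num_epochs=5,
    event_handler=lambda event: None,   # placeholder event handler
    reader=train_reader,
    feed_order=feed_order)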