From 6db240d78b3b515a1b2d885e8cc6d8e0b2ffd638 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 5 Jun 2018 19:25:55 +0800 Subject: [PATCH] update trainer about epoch_id and step id --- python/paddle/fluid/trainer.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index fbdd28f53..4ffc20645 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -188,7 +188,7 @@ class Trainer(object): if not self.checkpoint.is_pserver: epoch_id, step_id = io.load_trainer_args( self.checkpoint.checkpoint_dir, self.checkpoint.load_serial, - self.trainer_id, ["epoch_id", "step_id"]) + self.trainer_id, self._get_checkpoint_load_args()) self.checkpoint.epoch_id = int(epoch_id) self.checkpoint.step_id = int(step_id) @@ -432,22 +432,33 @@ class Trainer(object): return io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir) + def _get_checkpoint_load_args(self): + """ + epoch_id and step_id are runtime arguments, they are not variables, will load them independently. + """ + return ["epoch_id", "step_id"] + + def _get_checkpoint_save_args(self, epoch_id, step_id): + """ + epoch_id and step_id are runtime arguments, they are not variables, will save them independently. + """ + trainer_args = {} + trainer_args["epoch_id"] = epoch_id + trainer_args["step_id"] = step_id + return trainer_args + def _save_checkpoint(self, epoch_id, step_id): if not self.checkpoint: return if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0: - trainer_args = {} - trainer_args["epoch_id"] = epoch_id - trainer_args["step_id"] = step_id - exe = executor.Executor(self.place) io.save_checkpoint( executor=exe, checkpoint_dir=self.checkpoint.checkpoint_dir, trainer_id=self.trainer_id, is_chief=self.chief, - trainer_args=trainer_args, + trainer_args=self._get_checkpoint_save_args(epoch_id, step_id), main_program=self.train_program, max_num_checkpoints=self.checkpoint.max_num_checkpoints) -- GitLab