From 486e1e337d05679a22b389840136b9f07714646b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 29 May 2018 20:36:45 +0800 Subject: [PATCH] bug fix and optimize --- python/paddle/fluid/trainer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 206d582cd..35bb8ded5 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -79,8 +79,9 @@ class CheckpointConfig(object): else: self.step_interval = step_interval - self.epoch_id = 0 - self.step_id = 0 + self._epoch_id = 0 + self._step_id = 0 + self._load_serial = None def check_and_get_place(place): @@ -174,17 +175,17 @@ class Trainer(object): exe = executor.Executor(place) exe.run(self.startup_program) - if self.load_checkpoint_serial: + if self.checkpoint._load_serial: exe = executor.Executor(place) io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, - self.load_checkpoint_serial, + self.checkpoint._load_serial, self.startup_program) epoch_id, step_id = io.load_trainer_args( self.checkpoint.checkpoint_dir, self.load_checkpoint_serial, self.trainer_id, ["epoch_id", "step_id"]) - self.checkpoint.epoch_id = int(epoch_id) - self.checkpoint.step_id = int(step_id) + self.checkpoint._epoch_id = int(epoch_id) + self.checkpoint._step_id = int(step_id) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -351,7 +352,7 @@ class Trainer(object): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): epochs = [ epoch_id for epoch_id in range(num_epochs) - if epoch_id >= self.checkpoint.epoch_id + if epoch_id >= self.checkpoint._epoch_id ] for epoch_id in epochs: event_handler(BeginEpochEvent(epoch_id)) @@ -360,7 +361,8 @@ class Trainer(object): self._clean_checkpoint() return - if self.checkpoint and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: + if self.checkpoint and self.checkpoint._load_serial \ + and self.checkpoint._step_id >= step_id and self.checkpoint._epoch_id == epoch_id: continue begin_event = BeginStepEvent(epoch_id, step_id) -- GitLab