提交 486e1e33 编写于 作者: T tangwei12

bug fix and optimize

上级 ad9dfeb0
......@@ -79,8 +79,9 @@ class CheckpointConfig(object):
else:
self.step_interval = step_interval
self.epoch_id = 0
self.step_id = 0
self._epoch_id = 0
self._step_id = 0
self._load_serial = None
def check_and_get_place(place):
......@@ -174,17 +175,17 @@ class Trainer(object):
exe = executor.Executor(place)
exe.run(self.startup_program)
if self.load_checkpoint_serial:
if self.checkpoint._load_serial:
exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.load_checkpoint_serial,
self.checkpoint._load_serial,
self.startup_program)
epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.load_checkpoint_serial,
self.trainer_id, ["epoch_id", "step_id"])
self.checkpoint.epoch_id = int(epoch_id)
self.checkpoint.step_id = int(step_id)
self.checkpoint._epoch_id = int(epoch_id)
self.checkpoint._step_id = int(step_id)
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
......@@ -351,7 +352,7 @@ class Trainer(object):
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
epochs = [
epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id
if epoch_id >= self.checkpoint._epoch_id
]
for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id))
......@@ -360,7 +361,8 @@ class Trainer(object):
self._clean_checkpoint()
return
if self.checkpoint and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
if self.checkpoint and self.checkpoint._load_serial \
and self.checkpoint._step_id >= step_id and self.checkpoint._epoch_id == epoch_id:
continue
begin_event = BeginStepEvent(epoch_id, step_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册