提交 46f2688f 编写于 作者: T tangwei12

bug fix

上级 bca4da42
...@@ -356,10 +356,14 @@ class Trainer(object): ...@@ -356,10 +356,14 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader) self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
if self.checkpoint:
epochs = [ epochs = [
epoch_id for epoch_id in range(num_epochs) epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id if epoch_id >= self.checkpoint.epoch_id
] ]
else:
epochs = [epoch_id for epoch_id in range(num_epochs)]
for epoch_id in epochs: for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()): for step_id, data in enumerate(reader()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册