提交 a688c3fb 编写于 作者: W wuzewu

Update save_checkpoint api

上级 54a193c4
...@@ -443,7 +443,7 @@ class BasicTask(object): ...@@ -443,7 +443,7 @@ class BasicTask(object):
"step %d: [step/sec: %.2f]" % (self.current_step, run_speed)) "step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
def _save_ckpt_interval_event(self): def _save_ckpt_interval_event(self):
self.save_checkpoint(self.current_epoch, self.current_step) self.save_checkpoint()
def _eval_interval_event(self): def _eval_interval_event(self):
self.eval(phase="dev") self.eval(phase="dev")
...@@ -469,7 +469,7 @@ class BasicTask(object): ...@@ -469,7 +469,7 @@ class BasicTask(object):
# NOTE: current saved checkpoint machanism is not completed, # NOTE: current saved checkpoint machanism is not completed,
# it can't restore dataset training status # it can't restore dataset training status
def save_checkpoint(self, epoch, step): def save_checkpoint(self):
save_checkpoint( save_checkpoint(
checkpoint_dir=self.config.checkpoint_dir, checkpoint_dir=self.config.checkpoint_dir,
current_epoch=self.current_epoch, current_epoch=self.current_epoch,
...@@ -512,7 +512,7 @@ class BasicTask(object): ...@@ -512,7 +512,7 @@ class BasicTask(object):
self.env.current_epoch += 1 self.env.current_epoch += 1
# Save checkpoint after finetune # Save checkpoint after finetune
self.save_checkpoint(self.current_epoch + 1, self.current_step) self.save_checkpoint()
# Final evaluation # Final evaluation
self.eval(phase="dev") self.eval(phase="dev")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册