提交 a688c3fb 编写于 作者: W wuzewu

Update save_checkpoint api

上级 54a193c4
......@@ -443,7 +443,7 @@ class BasicTask(object):
"step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
def _save_ckpt_interval_event(self):
self.save_checkpoint(self.current_epoch, self.current_step)
self.save_checkpoint()
def _eval_interval_event(self):
self.eval(phase="dev")
......@@ -469,7 +469,7 @@ class BasicTask(object):
# NOTE: current saved checkpoint machanism is not completed,
# it can't restore dataset training status
def save_checkpoint(self, epoch, step):
def save_checkpoint(self):
save_checkpoint(
checkpoint_dir=self.config.checkpoint_dir,
current_epoch=self.current_epoch,
......@@ -512,7 +512,7 @@ class BasicTask(object):
self.env.current_epoch += 1
# Save checkpoint after finetune
self.save_checkpoint(self.current_epoch + 1, self.current_step)
self.save_checkpoint()
# Final evaluation
self.eval(phase="dev")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册