未验证 提交 50599e26 编写于 作者: W wangguanzhong 提交者: GitHub

fix model save for last epoch (#1761)

上级 19eb7f47
...@@ -196,7 +196,9 @@ def run(FLAGS, cfg, place): ...@@ -196,7 +196,9 @@ def run(FLAGS, cfg, place):
logger.info(strs) logger.info(strs)
# Save Stage # Save Stage
if ParallelEnv().local_rank == 0 and cur_eid % cfg.snapshot_epoch == 0: if ParallelEnv().local_rank == 0 and (
cur_eid % cfg.snapshot_epoch == 0 or
(cur_eid + 1) == int(cfg.epoch)):
cfg_name = os.path.basename(FLAGS.config).split('.')[0] cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(cur_eid) if cur_eid + 1 != int( save_name = str(cur_eid) if cur_eid + 1 != int(
cfg.epoch) else "model_final" cfg.epoch) else "model_final"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册