提交 ec78ccab 编写于 作者: C chenguowei01

save best model

上级 fe611760
......@@ -353,6 +353,9 @@ def train(cfg):
else:
print_info("Use multi-thread reader")
# 存储评估时最高mIoU
best_mIoU = 0
for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
py_reader.start()
while True:
......@@ -445,6 +448,14 @@ def train(cfg):
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
global_step)
# 将最优模型拷贝一份至best_model中
if mean_iou > best_mIoU:
best_mIoU = mean_iou
best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
if os.path.exists(best_model_dir):
shutil.rmtree(best_model_dir)
shutil.copytree(ckpt_dir, best_model_dir)
# Use Tensorboard to visualize results
if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None:
visualize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册