提交 d4df83f4 编写于 作者: W wuyefeilin 提交者: wuzewu

save best model (#123)

* save best model
上级 0f365acc
......@@ -179,6 +179,13 @@ def load_checkpoint(exe, program):
return begin_epoch
def update_best_model(ckpt_dir):
    """Replace the saved best model with a copy of ``ckpt_dir``.

    The destination is ``<cfg.TRAIN.MODEL_SAVE_DIR>/best_model``. The
    checkpoint is first copied to a staging directory next to the
    destination; only once that copy has fully succeeded is the previous
    best model removed and the staging directory renamed into place. A
    failed copy therefore never destroys the previous best model.

    Args:
        ckpt_dir: Path of the checkpoint directory to copy.

    Raises:
        OSError: Propagated from ``shutil`` / ``os`` if the copy, removal
            or rename fails (``shutil.Error`` is a subclass of ``OSError``).
    """
    best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
    staging_dir = best_model_dir + '.tmp'

    # A staging directory left behind by an interrupted earlier call would
    # make copytree fail, so clear it first.
    if os.path.exists(staging_dir):
        shutil.rmtree(staging_dir)

    # Copy before deleting anything: if this raises, the old best model is
    # still intact.
    shutil.copytree(ckpt_dir, staging_dir)

    if os.path.exists(best_model_dir):
        shutil.rmtree(best_model_dir)
    # Same parent directory, so a plain rename is sufficient.
    os.rename(staging_dir, best_model_dir)
def print_info(*msg):
    """Print ``msg`` only when ``cfg.TRAINER_ID`` is 0.

    All positional arguments are forwarded unchanged to ``print``. For any
    other trainer id the call does nothing.
    NOTE(review): presumably this keeps multi-trainer runs from duplicating
    log output -- confirm against how TRAINER_ID is assigned.
    """
    if cfg.TRAINER_ID != 0:
        return
    print(*msg)
......@@ -341,6 +348,8 @@ def train(cfg):
all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1)
avg_loss = 0.0
best_mIoU = 0.0
timer = Timer()
timer.start()
if begin_epoch > cfg.SOLVER.NUM_EPOCHS:
......@@ -445,6 +454,14 @@ def train(cfg):
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
global_step)
if mean_iou > best_mIoU:
best_mIoU = mean_iou
update_best_model(ckpt_dir)
print_info("Save best model {} to {}, mIoU = {:.4f}".format(
ckpt_dir,
os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model'),
mean_iou))
# Use Tensorboard to visualize results
if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None:
visualize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册