提交 6b954bda 编写于 作者: C chenguowei01

update train.py

上级 20067bab
...@@ -178,12 +178,14 @@ def load_checkpoint(exe, program): ...@@ -178,12 +178,14 @@ def load_checkpoint(exe, program):
return begin_epoch return begin_epoch
def update_best_model(ckpt_dir): def update_best_model(ckpt_dir):
best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
if os.path.exists(best_model_dir): if os.path.exists(best_model_dir):
shutil.rmtree(best_model_dir) shutil.rmtree(best_model_dir)
shutil.copytree(ckpt_dir, best_model_dir) shutil.copytree(ckpt_dir, best_model_dir)
def print_info(*msg): def print_info(*msg):
if cfg.TRAINER_ID == 0: if cfg.TRAINER_ID == 0:
print(*msg) print(*msg)
...@@ -455,7 +457,8 @@ def train(cfg): ...@@ -455,7 +457,8 @@ def train(cfg):
if mean_iou > best_mIoU: if mean_iou > best_mIoU:
best_mIoU = mean_iou best_mIoU = mean_iou
update_best_model(ckpt_dir) update_best_model(ckpt_dir)
print_info("Model {} has best mIoU, save it in {}".format(ckpt_dir, print_info("Model {} has best mIoU, save it in {}".format(
ckpt_dir,
os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model'))) os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')))
# Use Tensorboard to visualize results # Use Tensorboard to visualize results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册