From ec78ccab603a4a532cd53321c48ae5af756061ed Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Tue, 17 Dec 2019 16:35:53 +0800 Subject: [PATCH] save best model --- pdseg/train.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pdseg/train.py b/pdseg/train.py index 67a7bf5a..7b155a2b 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -353,6 +353,9 @@ def train(cfg): else: print_info("Use multi-thread reader") + # 存储评估时最高mIoU + best_mIoU = 0 + for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1): py_reader.start() while True: @@ -445,6 +448,14 @@ def train(cfg): log_writer.add_scalar('Evaluate/mean_acc', mean_acc, global_step) + # 将最优模型拷贝一份至best_model中 + if mean_iou > best_mIoU: + best_mIoU = mean_iou + best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') + if os.path.exists(best_model_dir): + shutil.rmtree(best_model_dir) + shutil.copytree(ckpt_dir, best_model_dir) + # Use Tensorboard to visualize results if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None: visualize( -- GitLab