diff --git a/pdseg/train.py b/pdseg/train.py index 67a7bf5a991c45e1e007a613ee6d0b6e8df7c280..7b155a2b44e0dad56cd818150e75d8de5fb81ee9 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -353,6 +353,9 @@ def train(cfg): else: print_info("Use multi-thread reader") + # 存储评估时最高mIoU + best_mIoU = 0 + for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1): py_reader.start() while True: @@ -445,6 +448,14 @@ def train(cfg): log_writer.add_scalar('Evaluate/mean_acc', mean_acc, global_step) + # 将最优模型拷贝一份至best_model中 + if mean_iou > best_mIoU: + best_mIoU = mean_iou + best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') + if os.path.exists(best_model_dir): + shutil.rmtree(best_model_dir) + shutil.copytree(ckpt_dir, best_model_dir) + # Use Tensorboard to visualize results if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None: visualize(