diff --git a/pdseg/train.py b/pdseg/train.py index 7b155a2b44e0dad56cd818150e75d8de5fb81ee9..3a717889d13784cbddfb0aa00979a69b6eb53a17 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -178,6 +178,11 @@ def load_checkpoint(exe, program): return begin_epoch +def update_best_model(ckpt_dir): + best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') + if os.path.exists(best_model_dir): + shutil.rmtree(best_model_dir) + shutil.copytree(ckpt_dir, best_model_dir) def print_info(*msg): if cfg.TRAINER_ID == 0: @@ -341,6 +346,8 @@ def train(cfg): all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1) avg_loss = 0.0 + best_mIoU = 0.0 + timer = Timer() timer.start() if begin_epoch > cfg.SOLVER.NUM_EPOCHS: @@ -353,9 +360,6 @@ def train(cfg): else: print_info("Use multi-thread reader") - # 存储评估时最高mIoU - best_mIoU = 0 - for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1): py_reader.start() while True: @@ -448,13 +452,11 @@ def train(cfg): log_writer.add_scalar('Evaluate/mean_acc', mean_acc, global_step) - # 将最优模型拷贝一份至best_model中 if mean_iou > best_mIoU: best_mIoU = mean_iou - best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') - if os.path.exists(best_model_dir): - shutil.rmtree(best_model_dir) - shutil.copytree(ckpt_dir, best_model_dir) + update_best_model(ckpt_dir) + print_info("Model {} has best mIoU, save it in {}".format(ckpt_dir, + os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model'))) # Use Tensorboard to visualize results if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None: