From 20067bab7f3953d035237687ce592f99019d1c50 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 18 Dec 2019 16:36:58 +0800 Subject: [PATCH] update train.py --- pdseg/train.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pdseg/train.py b/pdseg/train.py index 7b155a2b..3a717889 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -178,6 +178,11 @@ def load_checkpoint(exe, program): return begin_epoch +def update_best_model(ckpt_dir): + best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') + if os.path.exists(best_model_dir): + shutil.rmtree(best_model_dir) + shutil.copytree(ckpt_dir, best_model_dir) def print_info(*msg): if cfg.TRAINER_ID == 0: @@ -341,6 +346,8 @@ def train(cfg): all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1) avg_loss = 0.0 + best_mIoU = 0.0 + timer = Timer() timer.start() if begin_epoch > cfg.SOLVER.NUM_EPOCHS: @@ -353,9 +360,6 @@ def train(cfg): else: print_info("Use multi-thread reader") - # 存储评估时最高mIoU - best_mIoU = 0 - for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1): py_reader.start() while True: @@ -448,13 +452,11 @@ def train(cfg): log_writer.add_scalar('Evaluate/mean_acc', mean_acc, global_step) - # 将最优模型拷贝一份至best_model中 if mean_iou > best_mIoU: best_mIoU = mean_iou - best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') - if os.path.exists(best_model_dir): - shutil.rmtree(best_model_dir) - shutil.copytree(ckpt_dir, best_model_dir) + update_best_model(ckpt_dir) + print_info("Model {} has best mIoU, save it in {}".format(ckpt_dir, + os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model'))) # Use Tensorboard to visualize results if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None: -- GitLab