From 20ca4cc34ad52ffef89108e9c7f7f82ceb74174b Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 5 Sep 2019 11:10:56 +0800 Subject: [PATCH] fix add map_type (#3273) * fix add map_type --- tools/configure.py | 1 + tools/eval.py | 4 +++- tools/train.py | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tools/configure.py b/tools/configure.py index 84b031c39..8b43b00bc 100644 --- a/tools/configure.py +++ b/tools/configure.py @@ -35,6 +35,7 @@ MISC_CONFIG = { "save_dir": "", "weights": "", "metric": "", + "map_type": "11point", "log_smooth_window": 20, "snapshot_iter": 10000, "use_gpu": True, diff --git a/tools/eval.py b/tools/eval.py index 3df77b21f..d1be93908 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -136,8 +136,10 @@ def main(): resolution = None if 'mask' in results[0]: resolution = model.mask_head.resolution + # if map_type not set, use default 11point, only use in VOC eval + map_type = cfg.map_type if 'map_type' in cfg else '11point' eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution, - is_bbox_normalized, FLAGS.output_eval, cfg.map_type) + is_bbox_normalized, FLAGS.output_eval, map_type) if __name__ == '__main__': diff --git a/tools/train.py b/tools/train.py index 619ea972b..abcb7814a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -163,6 +163,9 @@ def main(): callable(model.is_bbox_normalized): is_bbox_normalized = model.is_bbox_normalized() + # if map_type not set, use default 11point, only use in VOC eval + map_type = cfg.map_type if 'map_type' in cfg else '11point' + train_stats = TrainingStats(cfg.log_smooth_window, train_keys) train_pyreader.start() start_time = time.time() @@ -200,7 +203,7 @@ def main(): if 'mask' in results[0]: resolution = model.mask_head.resolution box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes, - resolution, is_bbox_normalized, FLAGS.output_eval) + resolution, is_bbox_normalized, FLAGS.output_eval, map_type) if box_ap_stats[0] > best_box_ap_list[0]: best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[1] = it -- GitLab