From 589b21e3fc38597e0599c8c7bc3c9204222bb782 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 5 Sep 2019 11:10:56 +0800 Subject: [PATCH] fix add map_type (#3273) * fix add map_type --- PaddleCV/PaddleDetection/tools/configure.py | 1 + PaddleCV/PaddleDetection/tools/eval.py | 4 +++- PaddleCV/PaddleDetection/tools/train.py | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/PaddleCV/PaddleDetection/tools/configure.py b/PaddleCV/PaddleDetection/tools/configure.py index 84b031c3..8b43b00b 100644 --- a/PaddleCV/PaddleDetection/tools/configure.py +++ b/PaddleCV/PaddleDetection/tools/configure.py @@ -35,6 +35,7 @@ MISC_CONFIG = { "save_dir": "", "weights": "", "metric": "", + "map_type": "11point", "log_smooth_window": 20, "snapshot_iter": 10000, "use_gpu": True, diff --git a/PaddleCV/PaddleDetection/tools/eval.py b/PaddleCV/PaddleDetection/tools/eval.py index 3df77b21..d1be9390 100644 --- a/PaddleCV/PaddleDetection/tools/eval.py +++ b/PaddleCV/PaddleDetection/tools/eval.py @@ -136,8 +136,10 @@ def main(): resolution = None if 'mask' in results[0]: resolution = model.mask_head.resolution + # if map_type not set, use default 11point, only use in VOC eval + map_type = cfg.map_type if 'map_type' in cfg else '11point' eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution, - is_bbox_normalized, FLAGS.output_eval, cfg.map_type) + is_bbox_normalized, FLAGS.output_eval, map_type) if __name__ == '__main__': diff --git a/PaddleCV/PaddleDetection/tools/train.py b/PaddleCV/PaddleDetection/tools/train.py index 619ea972..abcb7814 100644 --- a/PaddleCV/PaddleDetection/tools/train.py +++ b/PaddleCV/PaddleDetection/tools/train.py @@ -163,6 +163,9 @@ def main(): callable(model.is_bbox_normalized): is_bbox_normalized = model.is_bbox_normalized() + # if map_type not set, use default 11point, only use in VOC eval + map_type = cfg.map_type if 'map_type' in cfg else '11point' + train_stats = TrainingStats(cfg.log_smooth_window, train_keys) train_pyreader.start() start_time = time.time() @@ -200,7 +203,7 @@ def main(): if 'mask' in results[0]: resolution = model.mask_head.resolution box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes, - resolution, is_bbox_normalized, FLAGS.output_eval) + resolution, is_bbox_normalized, FLAGS.output_eval, map_type) if box_ap_stats[0] > best_box_ap_list[0]: best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[1] = it -- GitLab