提交 20ca4cc3 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix add map_type (#3273)

* fix add map_type
上级 e9dea08b
...@@ -35,6 +35,7 @@ MISC_CONFIG = { ...@@ -35,6 +35,7 @@ MISC_CONFIG = {
"save_dir": "<value>", "save_dir": "<value>",
"weights": "<value>", "weights": "<value>",
"metric": "<value>", "metric": "<value>",
"map_type": "11point",
"log_smooth_window": 20, "log_smooth_window": 20,
"snapshot_iter": 10000, "snapshot_iter": 10000,
"use_gpu": True, "use_gpu": True,
......
...@@ -136,8 +136,10 @@ def main(): ...@@ -136,8 +136,10 @@ def main():
resolution = None resolution = None
if 'mask' in results[0]: if 'mask' in results[0]:
resolution = model.mask_head.resolution resolution = model.mask_head.resolution
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution, eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval, cfg.map_type) is_bbox_normalized, FLAGS.output_eval, map_type)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -163,6 +163,9 @@ def main(): ...@@ -163,6 +163,9 @@ def main():
callable(model.is_bbox_normalized): callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized() is_bbox_normalized = model.is_bbox_normalized()
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
train_stats = TrainingStats(cfg.log_smooth_window, train_keys) train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
train_pyreader.start() train_pyreader.start()
start_time = time.time() start_time = time.time()
...@@ -200,7 +203,7 @@ def main(): ...@@ -200,7 +203,7 @@ def main():
if 'mask' in results[0]: if 'mask' in results[0]:
resolution = model.mask_head.resolution resolution = model.mask_head.resolution
box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes, box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
resolution, is_bbox_normalized, FLAGS.output_eval) resolution, is_bbox_normalized, FLAGS.output_eval, map_type)
if box_ap_stats[0] > best_box_ap_list[0]: if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = it best_box_ap_list[1] = it
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册