Unverified commit 73acfad8, authored by Guanghua Yu, committed by GitHub

support compute per-category AP and PR curve (#2346)

Parent: aa2500cf
@@ -30,6 +30,7 @@ PaddleDetection in [tools](https://github.com/PaddlePaddle/PaddleDetection/tree/m
 | --draw_threshold | infer | Score threshold for visualization | 0.5 | Optional, `--draw_threshold=0.7` |
 | --infer_dir | infer | Directory of images to run inference on | None | Optional |
 | --infer_img | infer | Path of a single image to run inference on | None | Optional; at least one of `--infer_img` and `--infer_dir` must be set |
+| --classwise | eval | Whether to evaluate per-category AP and draw per-category PR curves | False | Optional |

 ### Training
...
@@ -115,19 +115,23 @@ class Trainer(object):
         if self.mode == 'test':
             self._metrics = []
             return
+        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
         if self.cfg.metric == 'COCO':
             # TODO: bias should be unified
             bias = self.cfg['bias'] if 'bias' in self.cfg else 0
             self._metrics = [
                 COCOMetric(
-                    anno_file=self.dataset.get_anno(), bias=bias)
+                    anno_file=self.dataset.get_anno(),
+                    classwise=classwise,
+                    bias=bias)
             ]
         elif self.cfg.metric == 'VOC':
             self._metrics = [
                 VOCMetric(
-                    anno_file=self.dataset.get_anno(),
+                    label_list=self.dataset.get_label_list(),
                     class_num=self.cfg.num_classes,
-                    map_type=self.cfg.map_type)
+                    map_type=self.cfg.map_type,
+                    classwise=classwise)
             ]
         elif self.cfg.metric == 'WiderFace':
             multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
...
@@ -17,8 +17,12 @@ from __future__ import division
 from __future__ import print_function

 import os
+import sys
+import numpy as np
+import itertools

 from ppdet.py_op.post_process import get_det_res, get_seg_res, get_solov2_segm_res
+from ppdet.metrics.map_utils import draw_pr_curve

 from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
@@ -59,7 +63,8 @@ def cocoapi_eval(jsonfile,
                  style,
                  coco_gt=None,
                  anno_file=None,
-                 max_dets=(100, 300, 1000)):
+                 max_dets=(100, 300, 1000),
+                 classwise=False):
     """
     Args:
         jsonfile: Evaluation json file, eg: bbox.json, mask.json.
@@ -68,6 +73,7 @@ def cocoapi_eval(jsonfile,
             eg: coco_gt = COCO(anno_file)
         anno_file: COCO annotations file.
         max_dets: COCO evaluation maxDets.
+        classwise: whether to compute per-category AP and draw the P-R curve.
     """
     assert coco_gt != None or anno_file != None
     from pycocotools.coco import COCO
@@ -86,4 +92,51 @@
     coco_eval.evaluate()
     coco_eval.accumulate()
     coco_eval.summarize()
+    if classwise:
+        # Compute per-category AP and PR curve
+        try:
+            from terminaltables import AsciiTable
+        except Exception as e:
+            logger.error(
+                'terminaltables not found, please install terminaltables, '
+                'for example: `pip install terminaltables`.')
+            raise e
+        precisions = coco_eval.eval['precision']
+        cat_ids = coco_gt.getCatIds()
+        # precision: (iou, recall, cls, area range, max dets)
+        assert len(cat_ids) == precisions.shape[2]
+        results_per_category = []
+        for idx, catId in enumerate(cat_ids):
+            # area range index 0: all area ranges
+            # max dets index -1: typically 100 per image
+            nm = coco_gt.loadCats(catId)[0]
+            precision = precisions[:, :, idx, 0, -1]
+            precision = precision[precision > -1]
+            if precision.size:
+                ap = np.mean(precision)
+            else:
+                ap = float('nan')
+            results_per_category.append(
+                (str(nm["name"]), '{:0.3f}'.format(float(ap))))
+            pr_array = precisions[0, :, idx, 0, 2]
+            recall_array = np.arange(0.0, 1.01, 0.01)
+            draw_pr_curve(
+                pr_array,
+                recall_array,
+                out_dir=style + '_pr_curve',
+                file_name='{}_precision_recall_curve.jpg'.format(nm["name"]))
+
+        num_columns = min(6, len(results_per_category) * 2)
+        results_flatten = list(itertools.chain(*results_per_category))
+        headers = ['category', 'AP'] * (num_columns // 2)
+        results_2d = itertools.zip_longest(
+            *[results_flatten[i::num_columns] for i in range(num_columns)])
+        table_data = [headers]
+        table_data += [result for result in results_2d]
+        table = AsciiTable(table_data)
+        logger.info('Per-category AP of {}: \n{}'.format(style, table.table))
+        logger.info('Per-category PR curves have been saved to the {} folder.'.
+                    format(style + '_pr_curve'))
+
+    # flush coco evaluation result
+    sys.stdout.flush()
     return coco_eval.stats
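For orientation, the classwise branch above reads per-category AP straight out of pycocotools' accumulated precision array. Below is a minimal standalone sketch of the same indexing, assuming a ground-truth/detection JSON pair named `annotations.json` and `bbox.json` (placeholder names, not part of this commit):

```python
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Placeholder file names; any COCO-format ground truth / detection pair works.
coco_gt = COCO('annotations.json')
coco_dt = coco_gt.loadRes('bbox.json')
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()

# eval['precision'] has shape (T, R, K, A, M):
#   T IoU thresholds, R recall thresholds, K categories,
#   A area ranges (index 0 = all), M maxDets settings (index -1 = largest).
precisions = coco_eval.eval['precision']
for idx, cat_id in enumerate(coco_gt.getCatIds()):
    name = coco_gt.loadCats(cat_id)[0]['name']
    p = precisions[:, :, idx, 0, -1]
    p = p[p > -1]  # -1 marks recall levels that were never reached
    ap = np.mean(p) if p.size else float('nan')
    print('{:>20s}  AP = {:.3f}'.format(name, ap))
```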
@@ -17,13 +17,42 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

+import os
 import sys
 import numpy as np
+import itertools

 from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)

-__all__ = ['bbox_area', 'jaccard_overlap', 'prune_zero_padding', 'DetectionMAP']
+__all__ = [
+    'draw_pr_curve', 'bbox_area', 'jaccard_overlap', 'prune_zero_padding',
+    'DetectionMAP'
+]
+
+
+def draw_pr_curve(precision,
+                  recall,
+                  iou=0.5,
+                  out_dir='pr_curve',
+                  file_name='precision_recall_curve.jpg'):
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    output_path = os.path.join(out_dir, file_name)
+    try:
+        import matplotlib.pyplot as plt
+    except Exception as e:
+        logger.error('Matplotlib not found, please install matplotlib, '
+                     'for example: `pip install matplotlib`.')
+        raise e
+    plt.cla()
+    plt.figure('P-R Curve')
+    plt.title('Precision/Recall Curve (IoU={})'.format(iou))
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
+    plt.grid(True)
+    plt.plot(recall, precision)
+    plt.savefig(output_path)
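A quick usage sketch for the helper above, using synthetic precision/recall values; the output directory and file name are illustrative only:

```python
import numpy as np
from ppdet.metrics.map_utils import draw_pr_curve

# Synthetic, monotonically decreasing precision over a 0..1 recall grid,
# purely to exercise the plotting helper.
recall = np.linspace(0.0, 1.0, 101)
precision = np.clip(1.0 - 0.8 * recall, 0.0, 1.0)

draw_pr_curve(
    precision,
    recall,
    iou=0.5,
    out_dir='demo_pr_curve',
    file_name='person_precision_recall_curve.jpg')
# Writes demo_pr_curve/person_precision_recall_curve.jpg
```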
 def bbox_area(bbox, is_bbox_normalized):
@@ -84,6 +113,8 @@ class DetectionMAP(object):
             is normalized to range[0, 1]. Default False.
         evaluate_difficult (bool): whether to evaluate
             difficult bounding boxes. Default False.
+        classwise (bool): whether to compute per-category AP and draw
+            the P-R curve. Default False.
     """

     def __init__(self,
@@ -91,7 +122,9 @@ class DetectionMAP(object):
                  overlap_thresh=0.5,
                  map_type='11point',
                  is_bbox_normalized=False,
-                 evaluate_difficult=False):
+                 evaluate_difficult=False,
+                 catid2name=None,
+                 classwise=False):
         self.class_num = class_num
         self.overlap_thresh = overlap_thresh
         assert map_type in ['11point', 'integral'], \
@@ -100,6 +133,10 @@ class DetectionMAP(object):
         self.map_type = map_type
         self.is_bbox_normalized = is_bbox_normalized
         self.evaluate_difficult = evaluate_difficult
+        self.classwise = classwise
+        self.classes = []
+        for cname in catid2name.values():
+            self.classes.append(cname)
         self.reset()

     def update(self, bbox, score, label, gt_box, gt_label, difficult=None):
@@ -155,6 +192,7 @@ class DetectionMAP(object):
         """
         mAP = 0.
         valid_cnt = 0
+        eval_results = []
         for score_pos, count in zip(self.class_score_poss,
                                     self.class_gt_counts):
             if count == 0: continue
@@ -170,6 +208,7 @@ class DetectionMAP(object):
                 precision.append(float(ac_tp) / (ac_tp + ac_fp))
                 recall.append(float(ac_tp) / count)

+            one_class_ap = 0.0
             if self.map_type == '11point':
                 max_precisions = [0.] * 11
                 start_idx = len(precision) - 1
@@ -183,23 +222,29 @@ class DetectionMAP(object):
                         else:
                             if max_precisions[j] < precision[i]:
                                 max_precisions[j] = precision[i]
-                mAP += sum(max_precisions) / 11.
+                one_class_ap = sum(max_precisions) / 11.
+                mAP += one_class_ap
                 valid_cnt += 1
             elif self.map_type == 'integral':
                 import math
-                ap = 0.
                 prev_recall = 0.
                 for i in range(len(precision)):
                     recall_gap = math.fabs(recall[i] - prev_recall)
                     if recall_gap > 1e-6:
-                        ap += precision[i] * recall_gap
+                        one_class_ap += precision[i] * recall_gap
                         prev_recall = recall[i]
-                mAP += ap
+                mAP += one_class_ap
                 valid_cnt += 1
             else:
                 logger.error("Unsupported mAP type {}".format(self.map_type))
                 sys.exit(1)
+            eval_results.append({
+                'class': self.classes[valid_cnt - 1],
+                'ap': one_class_ap,
+                'precision': precision,
+                'recall': recall,
+            })
+        self.eval_results = eval_results
         self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
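To make the two branches above concrete, here is a self-contained reformulation of how `one_class_ap` is obtained for a single class under the `11point` and `integral` modes. The precision/recall lists are made up, and the 11-point loop is written in its textbook form (max precision at each recall threshold), which is equivalent to the backward scan in the patch because recall is non-decreasing:

```python
import numpy as np

# Toy per-class curves, ordered by descending detection score, with recall
# accumulating as the TP/FP counts are accumulated.
precision = [1.0, 1.0, 0.67, 0.75, 0.60]
recall = [0.2, 0.4, 0.4, 0.6, 0.6]

# '11point' mode: average of the max precision at recall >= 0.0, 0.1, ..., 1.0.
max_precisions = []
for thresh in np.arange(0.0, 1.01, 0.1):
    candidates = [p for p, r in zip(precision, recall) if r >= thresh]
    max_precisions.append(max(candidates) if candidates else 0.0)
ap_11point = sum(max_precisions) / 11.0

# 'integral' mode: sum of precision times the recall increment at each point.
ap_integral, prev_recall = 0.0, 0.0
for p, r in zip(precision, recall):
    if abs(r - prev_recall) > 1e-6:
        ap_integral += p * (r - prev_recall)
        prev_recall = r

print('11point AP: {:.3f}, integral AP: {:.3f}'.format(ap_11point, ap_integral))
```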
     def get_map(self):
@@ -208,6 +253,39 @@
         """
         if self.mAP is None:
             logger.error("mAP is not calculated.")
+        if self.classwise:
+            # Compute per-category AP and PR curve
+            try:
+                from terminaltables import AsciiTable
+            except Exception as e:
+                logger.error(
+                    'terminaltables not found, please install terminaltables, '
+                    'for example: `pip install terminaltables`.')
+                raise e
+            results_per_category = []
+            for eval_result in self.eval_results:
+                results_per_category.append(
+                    (str(eval_result['class']),
+                     '{:0.3f}'.format(float(eval_result['ap']))))
+                draw_pr_curve(
+                    eval_result['precision'],
+                    eval_result['recall'],
+                    out_dir='voc_pr_curve',
+                    file_name='{}_precision_recall_curve.jpg'.format(
+                        eval_result['class']))
+            num_columns = min(6, len(results_per_category) * 2)
+            results_flatten = list(itertools.chain(*results_per_category))
+            headers = ['category', 'AP'] * (num_columns // 2)
+            results_2d = itertools.zip_longest(*[
+                results_flatten[i::num_columns] for i in range(num_columns)
+            ])
+            table_data = [headers]
+            table_data += [result for result in results_2d]
+            table = AsciiTable(table_data)
+            logger.info('Per-category AP of VOC: \n{}'.format(table.table))
+            logger.info(
+                'Per-category PR curves have been saved to the voc_pr_curve folder.')
         return self.mAP

     def _get_tp_fp_accum(self, score_pos_list):
...
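Both classwise branches print their results with the same `terminaltables` layout: the (category, AP) pairs are flattened and re-sliced into up to three column pairs. A small sketch with dummy numbers, assuming `terminaltables` is installed; the `None` cells produced by `zip_longest` padding are replaced with empty strings here just for display:

```python
import itertools
from terminaltables import AsciiTable

# Dummy (category, AP) pairs standing in for results_per_category.
results_per_category = [('person', '0.512'), ('car', '0.433'),
                        ('dog', '0.391'), ('cat', '0.472'), ('bus', '0.388')]

num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
# Slice the flat list into num_columns interleaved columns, then zip row-wise;
# zip_longest pads the last row when the category count is not a multiple of 3.
results_2d = itertools.zip_longest(
    *[results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers]
table_data += [['' if cell is None else cell for cell in row]
               for row in results_2d]
print(AsciiTable(table_data).table)
```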
@@ -63,6 +63,7 @@ class COCOMetric(Metric):
             "anno_file {} not a file".format(anno_file)
         self.anno_file = anno_file
         self.clsid2catid, self.catid2name = get_categories('COCO', anno_file)
+        self.classwise = kwargs.get('classwise', False)
         # TODO: bias should be unified
         self.bias = kwargs.get('bias', 0)
         self.reset()
@@ -98,7 +99,10 @@ class COCOMetric(Metric):
                 logger.info('The bbox result is saved to bbox.json.')

            bbox_stats = cocoapi_eval(
-                'bbox.json', 'bbox', anno_file=self.anno_file)
+                'bbox.json',
+                'bbox',
+                anno_file=self.anno_file,
+                classwise=self.classwise)
             self.eval_results['bbox'] = bbox_stats
             sys.stdout.flush()
@@ -108,7 +112,10 @@ class COCOMetric(Metric):
                 logger.info('The mask result is saved to mask.json.')

            seg_stats = cocoapi_eval(
-                'mask.json', 'segm', anno_file=self.anno_file)
+                'mask.json',
+                'segm',
+                anno_file=self.anno_file,
+                classwise=self.classwise)
             self.eval_results['mask'] = seg_stats
             sys.stdout.flush()
@@ -118,7 +125,10 @@ class COCOMetric(Metric):
                 logger.info('The segm result is saved to segm.json.')

            seg_stats = cocoapi_eval(
-                'segm.json', 'segm', anno_file=self.anno_file)
+                'segm.json',
+                'segm',
+                anno_file=self.anno_file,
+                classwise=self.classwise)
             self.eval_results['mask'] = seg_stats
             sys.stdout.flush()
@@ -131,16 +141,16 @@ class COCOMetric(Metric):
 class VOCMetric(Metric):
     def __init__(self,
-                 anno_file,
+                 label_list,
                  class_num=20,
                  overlap_thresh=0.5,
                  map_type='11point',
                  is_bbox_normalized=False,
-                 evaluate_difficult=False):
-        assert os.path.isfile(anno_file), \
-            "anno_file {} not a file".format(anno_file)
-        self.anno_file = anno_file
-        self.clsid2catid, self.catid2name = get_categories('VOC', anno_file)
+                 evaluate_difficult=False,
+                 classwise=False):
+        assert os.path.isfile(label_list), \
+            "label_list {} not a file".format(label_list)
+        self.clsid2catid, self.catid2name = get_categories('VOC', label_list)

         self.overlap_thresh = overlap_thresh
         self.map_type = map_type
@@ -150,7 +160,9 @@ class VOCMetric(Metric):
             overlap_thresh=overlap_thresh,
             map_type=map_type,
             is_bbox_normalized=is_bbox_normalized,
-            evaluate_difficult=evaluate_difficult)
+            evaluate_difficult=evaluate_difficult,
+            catid2name=self.catid2name,
+            classwise=classwise)
         self.reset()
...
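A hedged construction example for the updated VOCMetric, which now takes a VOC label-list file instead of an annotation file and forwards `catid2name` and `classwise` down to DetectionMAP. The import path and the `dataset/voc/label_list.txt` path are assumptions for illustration:

```python
from ppdet.metrics import VOCMetric

# Assumed label-list file: one class name per line, VOC style.
metric = VOCMetric(
    label_list='dataset/voc/label_list.txt',
    class_num=20,
    overlap_thresh=0.5,
    map_type='11point',
    classwise=True)

# Roughly speaking, the trainer feeds predictions into this metric batch by
# batch during evaluation; with classwise=True, the underlying DetectionMAP
# logs the per-category AP table and writes curves under voc_pr_curve/.
```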
@@ -5,3 +5,4 @@ opencv-python
 PyYAML
 shapely
 scipy
+terminaltables
\ No newline at end of file
@@ -64,6 +64,11 @@ def parse_args():
         action="store_true",
         help="whether add bias or not while getting w and h")
+    parser.add_argument(
+        "--classwise",
+        action="store_true",
+        help="whether to compute per-category AP and draw the P-R curve.")
+
     args = parser.parse_args()
     return args
@@ -88,6 +93,7 @@ def main():
     cfg = load_config(FLAGS.config)
     # TODO: bias should be unified
     cfg['bias'] = 1 if FLAGS.bias else 0
+    cfg['classwise'] = True if FLAGS.classwise else False
     merge_config(FLAGS.opt)
     if FLAGS.slim_config:
         slim_cfg = load_config(FLAGS.slim_config)
...
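Putting it together, the flag travels from the command line into the config and on to the metric objects: running the evaluation entry point with `--classwise` (e.g. `python tools/eval.py -c <config> --classwise`) sets `cfg['classwise']`, which the trainer forwards as `classwise=...` to COCOMetric/VOCMetric. A minimal programmatic sketch of the same toggle, with an illustrative config path:

```python
from ppdet.core.workspace import load_config

# 'configs/yolov3/yolov3_darknet53_270e_coco.yml' is only an example path.
cfg = load_config('configs/yolov3/yolov3_darknet53_270e_coco.yml')
cfg['classwise'] = True  # what passing --classwise does in main()

# The trainer then builds COCOMetric(..., classwise=cfg['classwise']) or
# VOCMetric(..., classwise=cfg['classwise']), which in turn forward the flag
# to cocoapi_eval() / DetectionMAP.
```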