提交 2bde2485 编写于 作者: V Vighnesh Birodkar 提交者: TF Object Detection Team

Cleanup COCO eval code and enable reporting all metrics with per category metrics.

PiperOrigin-RevId: 341188685
上级 30cc8e74
......@@ -1177,6 +1177,12 @@ def evaluator_options_from_eval_config(eval_config):
'include_metrics_per_category': (
eval_config.include_metrics_per_category)
}
if (hasattr(eval_config, 'all_metrics_per_category') and
eval_config.all_metrics_per_category):
evaluator_options[eval_metric_fn_key].update({
'all_metrics_per_category': eval_config.all_metrics_per_category
})
# For coco detection eval, if the eval_config proto contains the
# "skip_predictions_for_unlabeled_class" field, include this field in
# evaluator_options.
......
......@@ -961,6 +961,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
def __init__(self, categories,
include_metrics_per_category=False,
all_metrics_per_category=False,
super_categories=None):
"""Constructor.
......@@ -969,6 +970,10 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
include_metrics_per_category: If True, include metrics for each category.
all_metrics_per_category: Whether to include all the summary metrics for
each category in per_category_ap. Be careful with setting it to true if
you have more than handful of categories, because it will pollute
your mldash.
super_categories: None or a python dict mapping super-category names
(strings) to lists of categories (corresponding to category names
in the label_map). Metrics are aggregated along these super-categories
......@@ -984,6 +989,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
self._annotation_id = 1
self._include_metrics_per_category = include_metrics_per_category
self._super_categories = super_categories
self._all_metrics_per_category = all_metrics_per_category
def clear(self):
"""Clears the state to prepare for a fresh evaluation."""
......@@ -1177,7 +1183,8 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
agnostic_mode=False, iou_type='segm')
mask_metrics, mask_per_category_ap = mask_evaluator.ComputeMetrics(
include_metrics_per_category=self._include_metrics_per_category,
super_categories=self._super_categories)
super_categories=self._super_categories,
all_metrics_per_category=self._all_metrics_per_category)
mask_metrics.update(mask_per_category_ap)
mask_metrics = {'DetectionMasks_'+ key: value
for key, value in mask_metrics.items()}
......
......@@ -142,6 +142,35 @@ class COCOWrapper(coco.COCO):
return results
COCO_METRIC_NAMES_AND_INDEX = (
('Precision/mAP', 0),
('Precision/mAP@.50IOU', 1),
('Precision/mAP@.75IOU', 2),
('Precision/mAP (small)', 3),
('Precision/mAP (medium)', 4),
('Precision/mAP (large)', 5),
('Recall/AR@1', 6),
('Recall/AR@10', 7),
('Recall/AR@100', 8),
('Recall/AR@100 (small)', 9),
('Recall/AR@100 (medium)', 10),
('Recall/AR@100 (large)', 11)
)
COCO_KEYPOINT_METRIC_NAMES_AND_INDEX = (
('Precision/mAP', 0),
('Precision/mAP@.50IOU', 1),
('Precision/mAP@.75IOU', 2),
('Precision/mAP (medium)', 3),
('Precision/mAP (large)', 4),
('Recall/AR@1', 5),
('Recall/AR@10', 6),
('Recall/AR@100', 7),
('Recall/AR@100 (medium)', 8),
('Recall/AR@100 (large)', 9)
)
class COCOEvalWrapper(cocoeval.COCOeval):
"""Wrapper for the pycocotools COCOeval class.
......@@ -259,42 +288,17 @@ class COCOEvalWrapper(cocoeval.COCOeval):
summary_metrics = {}
if self._iou_type in ['bbox', 'segm']:
summary_metrics = OrderedDict([('Precision/mAP', self.stats[0]),
('Precision/mAP@.50IOU', self.stats[1]),
('Precision/mAP@.75IOU', self.stats[2]),
('Precision/mAP (small)', self.stats[3]),
('Precision/mAP (medium)', self.stats[4]),
('Precision/mAP (large)', self.stats[5]),
('Recall/AR@1', self.stats[6]),
('Recall/AR@10', self.stats[7]),
('Recall/AR@100', self.stats[8]),
('Recall/AR@100 (small)', self.stats[9]),
('Recall/AR@100 (medium)', self.stats[10]),
('Recall/AR@100 (large)', self.stats[11])])
summary_metrics = OrderedDict(
[(name, self.stats[index]) for name, index in
COCO_METRIC_NAMES_AND_INDEX])
elif self._iou_type == 'keypoints':
category_id = self.GetCategoryIdList()[0]
category_name = self.GetCategory(category_id)['name']
summary_metrics = OrderedDict([])
summary_metrics['Precision/mAP ByCategory/{}'.format(
category_name)] = self.stats[0]
summary_metrics['Precision/mAP@.50IOU ByCategory/{}'.format(
category_name)] = self.stats[1]
summary_metrics['Precision/mAP@.75IOU ByCategory/{}'.format(
category_name)] = self.stats[2]
summary_metrics['Precision/mAP (medium) ByCategory/{}'.format(
category_name)] = self.stats[3]
summary_metrics['Precision/mAP (large) ByCategory/{}'.format(
category_name)] = self.stats[4]
summary_metrics['Recall/AR@1 ByCategory/{}'.format(
category_name)] = self.stats[5]
summary_metrics['Recall/AR@10 ByCategory/{}'.format(
category_name)] = self.stats[6]
summary_metrics['Recall/AR@100 ByCategory/{}'.format(
category_name)] = self.stats[7]
summary_metrics['Recall/AR@100 (medium) ByCategory/{}'.format(
category_name)] = self.stats[8]
summary_metrics['Recall/AR@100 (large) ByCategory/{}'.format(
category_name)] = self.stats[9]
for metric_name, index in COCO_KEYPOINT_METRIC_NAMES_AND_INDEX:
value = self.stats[index]
summary_metrics['{} ByCategory/{}'.format(
metric_name, category_name)] = value
if not include_metrics_per_category:
return summary_metrics, {}
if not hasattr(self, 'category_stats'):
......@@ -303,48 +307,51 @@ class COCOEvalWrapper(cocoeval.COCOeval):
super_category_ap = OrderedDict([])
if self.GetAgnosticMode():
return summary_metrics, per_category_ap
if super_categories:
for key in super_categories:
super_category_ap['PerformanceBySuperCategory/{}'.format(key)] = 0
if all_metrics_per_category:
for metric_name, _ in COCO_METRIC_NAMES_AND_INDEX:
metric_key = '{} BySuperCategory/{}'.format(metric_name, key)
super_category_ap[metric_key] = 0
for category_index, category_id in enumerate(self.GetCategoryIdList()):
category = self.GetCategory(category_id)['name']
# Kept for backward compatilbility
per_category_ap['PerformanceByCategory/mAP/{}'.format(
category)] = self.category_stats[0][category_index]
if all_metrics_per_category:
for metric_name, index in COCO_METRIC_NAMES_AND_INDEX:
metric_key = '{} ByCategory/{}'.format(metric_name, category)
per_category_ap[metric_key] = self.category_stats[index][
category_index]
if super_categories:
for key in super_categories:
if category in super_categories[key]:
metric_name = 'PerformanceBySuperCategory/{}'.format(key)
if metric_name not in super_category_ap:
super_category_ap[metric_name] = 0
super_category_ap[metric_name] += self.category_stats[0][
metric_key = 'PerformanceBySuperCategory/{}'.format(key)
super_category_ap[metric_key] += self.category_stats[0][
category_index]
if all_metrics_per_category:
per_category_ap['Precision mAP ByCategory/{}'.format(
category)] = self.category_stats[0][category_index]
per_category_ap['Precision mAP@.50IOU ByCategory/{}'.format(
category)] = self.category_stats[1][category_index]
per_category_ap['Precision mAP@.75IOU ByCategory/{}'.format(
category)] = self.category_stats[2][category_index]
per_category_ap['Precision mAP (small) ByCategory/{}'.format(
category)] = self.category_stats[3][category_index]
per_category_ap['Precision mAP (medium) ByCategory/{}'.format(
category)] = self.category_stats[4][category_index]
per_category_ap['Precision mAP (large) ByCategory/{}'.format(
category)] = self.category_stats[5][category_index]
per_category_ap['Recall AR@1 ByCategory/{}'.format(
category)] = self.category_stats[6][category_index]
per_category_ap['Recall AR@10 ByCategory/{}'.format(
category)] = self.category_stats[7][category_index]
per_category_ap['Recall AR@100 ByCategory/{}'.format(
category)] = self.category_stats[8][category_index]
per_category_ap['Recall AR@100 (small) ByCategory/{}'.format(
category)] = self.category_stats[9][category_index]
per_category_ap['Recall AR@100 (medium) ByCategory/{}'.format(
category)] = self.category_stats[10][category_index]
per_category_ap['Recall AR@100 (large) ByCategory/{}'.format(
category)] = self.category_stats[11][category_index]
if all_metrics_per_category:
for metric_name, index in COCO_METRIC_NAMES_AND_INDEX:
metric_key = '{} BySuperCategory/{}'.format(metric_name, key)
super_category_ap[metric_key] += (
self.category_stats[index][category_index])
if super_categories:
for key in super_categories:
metric_name = 'PerformanceBySuperCategory/{}'.format(key)
super_category_ap[metric_name] /= len(super_categories[key])
length = len(super_categories[key])
super_category_ap['PerformanceBySuperCategory/{}'.format(
key)] /= length
if all_metrics_per_category:
for metric_name, _ in COCO_METRIC_NAMES_AND_INDEX:
super_category_ap['{} BySuperCategory/{}'.format(
metric_name, key)] /= length
per_category_ap.update(super_category_ap)
return summary_metrics, per_category_ap
......
......@@ -3,7 +3,7 @@ syntax = "proto2";
package object_detection.protos;
// Message for configuring DetectionModel evaluation jobs (eval.py).
// Next id - 35
// Next id - 36
message EvalConfig {
optional uint32 batch_size = 25 [default = 1];
// Number of visualization images to generate.
......@@ -82,6 +82,9 @@ message EvalConfig {
// If True, additionally include per-category metrics.
optional bool include_metrics_per_category = 24 [default = false];
// If true, includes all metrics per category.
optional bool all_metrics_per_category = 35 [default=false];
// Optional super-category definitions: keys are super-category names;
// values are comma-separated categories (assumed to correspond to category
// names (`display_name`) in the label map.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册