coco_metric.py 3.6 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import json
D
dengkaipeng 已提交
17 18
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
D
dengkaipeng 已提交
19 20 21 22 23 24 25 26 27 28 29

import logging
# Log line format used by the basicConfig call below.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Configures root logging at import time (INFO level).
# NOTE(review): calling basicConfig from a library module affects the whole
# process's logging setup; consider moving it to the entry script.
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ['COCOMetric']

# Default path that COCOMetric.accumulate() writes the buffered detections
# to before handing them to pycocotools.
OUTFILE = './bbox.json'


# COCOMetric behaves differently from the Metric defined in the high
# level API: COCOMetric can only accumulate at the end of an epoch,
# so we implement COCOMetric as a plain class rather than a high level
# API Metric.
class COCOMetric():
    """
    Metric for the MS-COCO dataset; only supports ``update`` with a batch
    size of 1.

    Detections are buffered in memory by ``update``. ``accumulate`` dumps
    them to a JSON file and evaluates that file with pycocotools.

    Args:
        anno_path(str): path to COCO annotation json file
        with_background(bool): whether load category id with
                               background as 0, default True
        outfile(str|None): path the buffered detections are written to
                           before evaluation; when None the module-level
                           ``OUTFILE`` is used
    """

    def __init__(self, anno_path, with_background=True, outfile=None,
                 **kwargs):
        self.anno_path = anno_path
        self.with_background = with_background
        # Resolved in accumulate() so that None falls back to OUTFILE.
        self.outfile = outfile
        self.bbox_results = []
        self.result = 0.0

        self.coco_gt = COCO(anno_path)
        cat_ids = self.coco_gt.getCatIds()
        # Contiguous class index -> COCO category id. With a background
        # class, index 0 is left unmapped and real classes start at 1.
        self.clsid2catid = {
            i + int(with_background): catid
            for i, catid in enumerate(cat_ids)
        }

    def update(self, img_id, bboxes):
        """
        Buffer the detections of a single image for later evaluation.

        Args:
            img_id: array-like whose first dimension must be 1 (batch
                size 1); converted to a Python int with ``int()``.
            bboxes: 2-D array-like whose rows are
                ``[clsid, score, xmin, ymin, xmax, ymax]``. Any other shape
                (e.g. a single-column placeholder or an empty 1-D array)
                is treated as "nothing detected" and the call is a no-op.

        Raises:
            AssertionError: if the batch size is not 1.
            KeyError: if a class id has no entry in ``clsid2catid``.
        """
        assert img_id.shape[0] == 1, \
            "COCOMetric can only update with batch size = 1"
        if len(bboxes.shape) != 2 or bboxes.shape[1] != 6:
            # no bbox detected in this batch
            return

        img_id = int(img_id)
        for clsid, score, xmin, ymin, xmax, ymax in bboxes.tolist():
            # Results use [x, y, width, height]; the +1 treats xmax/ymax
            # as inclusive pixel coordinates.
            w = xmax - xmin + 1
            h = ymax - ymin + 1
            self.bbox_results.append({
                'image_id': img_id,
                'category_id': self.clsid2catid[int(clsid)],
                'bbox': [xmin, ymin, w, h],
                'score': score,
            })

    def reset(self):
        """Drop all buffered detections."""
        self.bbox_results = []

    def accumulate(self):
        """
        Evaluate all buffered detections with the COCO bbox protocol.

        Returns:
            list: a one-element list holding the first entry of
            ``COCOeval.stats`` (AP@[.5:.95] in pycocotools), or ``[0.0]``
            when no detections were buffered.
        """
        if not self.bbox_results:
            logger.warning("The number of valid bbox detected is zero.\n"
                           "Please use reasonable model and check input "
                           "data.\nstop COCOMetric accumulate!")
            self.result = 0.0
            return [self.result]

        outfile = self.outfile or OUTFILE
        with open(outfile, 'w') as f:
            json.dump(self.bbox_results, f)

        map_stats = self.cocoapi_eval(outfile, 'bbox', coco_gt=self.coco_gt)
        # flush coco evaluation result printed by summarize()
        sys.stdout.flush()
        self.result = map_stats[0]
        return [self.result]

    def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
        """
        Run pycocotools evaluation on a COCO-format result file.

        Args:
            jsonfile(str): path to the detection result json file
            style(str): iouType passed to COCOeval, e.g. 'bbox'
            coco_gt(COCO|None): ground-truth API object; loaded from
                ``anno_file`` when None
            anno_file(str|None): annotation json path, used only when
                ``coco_gt`` is None

        Returns:
            the ``stats`` attribute of the finished ``COCOeval``
        """
        assert coco_gt is not None or anno_file is not None, \
            "either coco_gt or anno_file must be given"

        if coco_gt is None:
            coco_gt = COCO(anno_file)
        logger.info("Start evaluate...")
        coco_dt = coco_gt.loadRes(jsonfile)
        coco_eval = COCOeval(coco_gt, coco_dt, style)
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        return coco_eval.stats
D
dengkaipeng 已提交
109