coco_metric.py 3.5 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import json
D
dengkaipeng 已提交
17 18
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
D
dengkaipeng 已提交
19 20 21 22 23 24 25 26 27 28 29 30

import logging
# Log line layout used by this module's root-logger configuration.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# NOTE(review): this configures the root logger as a side effect of import.
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Public API of this module.
__all__ = ['COCOMetric']

# Destination of the COCO-format detection json written before evaluation.
OUTFILE = './bbox.json'


# considered to change to a callback later
class COCOMetric:
    """
    Metric for the MS-COCO dataset; only supports updates with a batch
    size of 1.

    Detections are collected by ``update``, dumped to ``OUTFILE`` as a
    COCO-format json list and scored with pycocotools in ``accumulate``.

    Args:
        anno_path(str): path to COCO annotation json file
        with_background(bool): whether model class ids are offset by one
                               so that 0 is the background, default True
    """

    def __init__(self, anno_path, with_background=True, **kwargs):
        # **kwargs is accepted and ignored (presumably so generic config
        # dicts can be passed through) -- NOTE(review): confirm with callers.
        self.anno_path = anno_path
        self.with_background = with_background
        self.bbox_results = []

        self.coco_gt = COCO(anno_path)
        cat_ids = self.coco_gt.getCatIds()
        # Map contiguous model class ids to (possibly sparse) COCO category
        # ids; shift by one when id 0 is reserved for the background.
        self.clsid2catid = {
            i + int(with_background): catid
            for i, catid in enumerate(cat_ids)
        }

    def update(self, img_id, bboxes):
        """
        Record the detections of a single image.

        Args:
            img_id: array-like whose first dimension is 1, holding the
                    image id
            bboxes: 2-D array whose rows are
                    [clsid, score, xmin, ymin, xmax, ymax]; a column count
                    other than 6 is treated as "no detections"
        """
        assert img_id.shape[0] == 1, \
            "COCOMetric can only update with batch size = 1"
        if bboxes.shape[1] != 6:
            # no bbox detected in this batch
            return

        img_id = int(img_id)
        for clsid, score, xmin, ymin, xmax, ymax in bboxes.tolist():
            # COCO boxes are [x, y, width, height]; the +1 treats xmax/ymax
            # as inclusive pixel coordinates.
            bbox = [xmin, ymin, xmax - xmin + 1, ymax - ymin + 1]
            self.bbox_results.append({
                'image_id': img_id,
                'category_id': self.clsid2catid[int(clsid)],
                'bbox': bbox,
                'score': score,
            })

    def reset(self):
        """Drop all detections recorded so far."""
        self.bbox_results = []

    def accumulate(self):
        """
        Evaluate the recorded detections against the ground truth.

        Returns:
            list: a one-element list holding ``COCOeval.stats[0]`` (in
            pycocotools, AP averaged over IoU 0.50:0.95), or ``[0.0]``
            when no detection was recorded. ``self.result`` is set to the
            same value in both cases.
        """
        if not self.bbox_results:
            logger.warning(
                "The number of valid bbox detected is zero.\n"
                "Please use reasonable model and check input data.\n"
                "stop COCOMetric accumulate!")
            # Keep self.result consistent with the return value so a score
            # from an earlier run is not left behind.
            self.result = 0.0
            return [self.result]

        with open(OUTFILE, 'w') as f:
            json.dump(self.bbox_results, f)

        map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
        # flush coco evaluation result
        sys.stdout.flush()
        self.result = map_stats[0]
        return [self.result]

    def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
        """
        Run the pycocotools evaluation on a COCO-format result file.

        Args:
            jsonfile(str): path to the detection results json
            style(str): iouType passed to COCOeval, e.g. 'bbox'
            coco_gt(COCO|None): preloaded ground truth; used when given
            anno_file(str|None): annotation path, loaded only when
                                 coco_gt is None

        Returns:
            the ``stats`` array of the finished COCOeval run
        """
        assert coco_gt is not None or anno_file is not None, \
            "either coco_gt or anno_file must be provided"

        if coco_gt is None:
            coco_gt = COCO(anno_file)
        logger.info("Start evaluate...")
        coco_dt = coco_gt.loadRes(jsonfile)
        coco_eval = COCOeval(coco_gt, coco_dt, style)
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        return coco_eval.stats