metrics.py 8.4 KB
Newer Older
K
Kaipeng Deng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import paddle
import numpy as np

from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
27
from .widerface_utils import face_eval_run
K
Kaipeng Deng 已提交
28
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
29 30 31 32

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

33 34 35
# Public names of this module. `get_infer_results` is imported from
# `.coco_utils` above and re-exported here next to the metric classes.
__all__ = [
    'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results'
]
K
Kaipeng Deng 已提交
36 37 38 39 40 41


class Metric(paddle.metric.Metric):
    """Common base class of the metrics defined in this module.

    paddle.metric.Metric defines :meth:`update`, :meth:`accumulate` and
    :meth:`reset`; in ppdet two more methods are needed on top of those:
    :meth:`log` and :meth:`get_results`. Every hook below is a no-op that
    subclasses override as required.
    """

    def name(self):
        """Return the class name of the concrete metric."""
        return type(self).__name__

    def reset(self):
        """Clear accumulated state; does nothing by default."""

    def accumulate(self):
        """Compute the final metric; does nothing by default."""

    def log(self):
        """Abstract method for logging metric results."""

    def get_results(self):
        """Abstract method for getting metric results."""


class COCOMetric(Metric):
    """Metric that collects predictions and scores them with `cocoapi_eval`.

    Predictions are gathered batch by batch in :meth:`update`; in
    :meth:`accumulate` each non-empty result list is dumped to a JSON file
    and evaluated against `anno_file`.

    Args:
        anno_file (str): path of the annotation file, must be an existing
            file.
        **kwargs: optional settings
            clsid2catid: mapping forwarded to `get_infer_results`; when
                missing or None it is loaded with
                `get_categories('COCO', anno_file)`.
            classwise (bool): forwarded to `cocoapi_eval`, default False.
            output_eval (str|None): directory that receives the result
                JSON files; when unset they are written relative to the
                current working directory.
            bias: forwarded to `get_infer_results`, default 0.
    """

    # Keys of `self.results`, in the order they are evaluated.
    _RESULT_KEYS = ('bbox', 'mask', 'segm')

    def __init__(self, anno_file, **kwargs):
        assert os.path.isfile(anno_file), \
                "anno_file {} not a file".format(anno_file)
        self.anno_file = anno_file
        self.clsid2catid = kwargs.get('clsid2catid', None)
        if self.clsid2catid is None:
            self.clsid2catid, _ = get_categories('COCO', anno_file)
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        # TODO: bias should be unified
        self.bias = kwargs.get('bias', 0)
        self.reset()

    def reset(self):
        """Drop all collected predictions and computed statistics."""
        # only bbox and mask evaluation support currently
        self.results = {key: [] for key in self._RESULT_KEYS}
        self.eval_results = {}

    def update(self, inputs, outputs):
        """Convert one batch of outputs and append them to the results.

        Args:
            inputs (dict): batch inputs; only `inputs['im_id']` is read.
            outputs (dict): model outputs; `paddle.Tensor` values are
                converted to numpy arrays, other values pass through.
        """
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v

        im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        infer_results = get_infer_results(
            outs, self.clsid2catid, bias=self.bias)
        for key in self._RESULT_KEYS:
            if key in infer_results:
                self.results[key] += infer_results[key]

    def accumulate(self):
        """Dump every non-empty result list to JSON and evaluate it.

        The statistics returned by `cocoapi_eval` are stored in
        `self.eval_results`.
        """
        # (results key, cocoapi_eval style, eval_results key)
        # NOTE(review): 'segm' statistics are stored under the 'mask' key,
        # exactly as before, so they replace 'mask' statistics computed in
        # the same call. Confirm callers rely on the 'mask' key before
        # changing this.
        plan = (
            ('bbox', 'bbox', 'bbox'),
            ('mask', 'segm', 'mask'),
            ('segm', 'segm', 'mask'), )
        for result_key, style, eval_key in plan:
            if len(self.results[result_key]) > 0:
                self.eval_results[eval_key] = self._dump_and_eval(result_key,
                                                                  style)
                sys.stdout.flush()

    def _dump_and_eval(self, result_key, style):
        """Write `self.results[result_key]` to JSON and run `cocoapi_eval`.

        Args:
            result_key (str): key of `self.results`; the file is named
                `<result_key>.json`.
            style (str): evaluation style passed to `cocoapi_eval`.

        Returns:
            The value returned by `cocoapi_eval`.
        """
        output = "{}.json".format(result_key)
        if self.output_eval:
            # open() does not create missing directories. Tolerate another
            # process creating the directory between the check and the call.
            if not os.path.isdir(self.output_eval):
                try:
                    os.makedirs(self.output_eval)
                except OSError:
                    if not os.path.isdir(self.output_eval):
                        raise
            output = os.path.join(self.output_eval, output)
        with open(output, 'w') as f:
            json.dump(self.results[result_key], f)
        # Report the path that was really written, not just the file name.
        logger.info('The {} result is saved to {}.'.format(result_key, output))

        return cocoapi_eval(
            output,
            style,
            anno_file=self.anno_file,
            classwise=self.classwise)

    def log(self):
        """Nothing to log here."""
        pass

    def get_results(self):
        """Return the statistics collected by :meth:`accumulate`."""
        return self.eval_results


class VOCMetric(Metric):
    """Metric that feeds detections into a `DetectionMAP` accumulator.

    Args:
        label_list (str): path of the label list file, must be an existing
            file; it is passed to `get_categories('VOC', label_list)`.
        class_num (int): forwarded to `DetectionMAP`, default 20.
        overlap_thresh (float): forwarded to `DetectionMAP` and shown in
            the log line, default 0.5.
        map_type (str): forwarded to `DetectionMAP` and shown in the log
            line, default '11point'.
        is_bbox_normalized (bool): forwarded to `DetectionMAP`.
        evaluate_difficult (bool): forwarded to `DetectionMAP`; when False
            the `difficult` input is read and passed along, when True it
            is replaced by None.
        classwise (bool): forwarded to `DetectionMAP`.
    """

    def __init__(self,
                 label_list,
                 class_num=20,
                 overlap_thresh=0.5,
                 map_type='11point',
                 is_bbox_normalized=False,
                 evaluate_difficult=False,
                 classwise=False):
        assert os.path.isfile(label_list), \
                "label_list {} not a file".format(label_list)
        self.clsid2catid, self.catid2name = get_categories('VOC', label_list)

        self.overlap_thresh = overlap_thresh
        self.map_type = map_type
        self.evaluate_difficult = evaluate_difficult
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=overlap_thresh,
            map_type=map_type,
            is_bbox_normalized=is_bbox_normalized,
            evaluate_difficult=evaluate_difficult,
            catid2name=self.catid2name,
            classwise=classwise)

        self.reset()

    def reset(self):
        """Reset the underlying `DetectionMAP` accumulator."""
        self.detection_map.reset()

    @staticmethod
    def _to_numpy(value):
        """Return `value.numpy()` for a paddle.Tensor, else `value` as is."""
        return value.numpy() if isinstance(value, paddle.Tensor) else value

    def update(self, inputs, outputs):
        """Add one batch of detections and ground truth to the accumulator.

        Args:
            inputs (dict): reads 'gt_bbox', 'gt_class', optionally
                'scale_factor', and 'difficult' unless
                `evaluate_difficult` is set.
            outputs (dict): reads 'bbox' (column 0 is used as label,
                column 1 as score, the remaining columns as the box) and
                'bbox_num' (number of 'bbox' rows per image).

        Values may be `paddle.Tensor` or already-converted arrays.
        """
        bbox_out = self._to_numpy(outputs['bbox'])
        # This check has to happen before slicing: a (1, 1) output cannot
        # be split into label/score/box columns. The previous test ran on
        # the sliced boxes, where such an output has shape (1, 0), so it
        # could never match.
        if bbox_out is None or bbox_out.shape == (1, 1):
            return

        bboxes = bbox_out[:, 2:]
        scores = bbox_out[:, 1]
        labels = bbox_out[:, 0]
        bbox_lengths = self._to_numpy(outputs['bbox_num'])

        gt_boxes = self._to_numpy(inputs['gt_bbox'])
        gt_labels = self._to_numpy(inputs['gt_class'])
        if self.evaluate_difficult:
            difficults = None
        else:
            difficults = self._to_numpy(inputs['difficult'])

        if 'scale_factor' in inputs:
            scale_factor = self._to_numpy(inputs['scale_factor'])
        else:
            # no scaling information: use a factor of 1 for every image
            scale_factor = np.ones((gt_boxes.shape[0], 2)).astype('float32')

        bbox_idx = 0
        for i in range(gt_boxes.shape[0]):
            # each scale_factor row is unpacked as (h, w)
            h, w = scale_factor[i]
            gt_box = gt_boxes[i] / np.array([w, h, w, h])
            gt_label = gt_labels[i]
            difficult = None if difficults is None \
                            else difficults[i]
            # detections of all images are concatenated; bbox_num rows
            # starting at bbox_idx belong to image i
            bbox_num = bbox_lengths[i]
            bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
            score = scores[bbox_idx:bbox_idx + bbox_num]
            label = labels[bbox_idx:bbox_idx + bbox_num]
            gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
                                                             difficult)
            self.detection_map.update(bbox, score, label, gt_box, gt_label,
                                      difficult)
            bbox_idx += bbox_num

    def accumulate(self):
        """Let the `DetectionMAP` accumulator compute its result."""
        logger.info("Accumulating evaluatation results...")
        self.detection_map.accumulate()

    def log(self):
        """Log the mAP as a percentage with the threshold and map type."""
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        """Return the mAP wrapped as `{'bbox': [mAP]}`."""
        return {'bbox': [self.detection_map.get_map()]}
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247


class WiderFaceMetric(Metric):
    """Metric wrapper that delegates evaluation to `face_eval_run`.

    Args:
        image_dir: image directory handed to `face_eval_run`.
        anno_file: annotation file handed to `face_eval_run`.
        multi_scale (bool): forwarded to `face_eval_run`, default True.
    """

    def __init__(self, image_dir, anno_file, multi_scale=True):
        self.multi_scale = multi_scale
        self.anno_file = anno_file
        self.image_dir = image_dir
        categories = get_categories('widerface')
        self.clsid2catid, self.catid2name = categories

    def update(self, model):
        """Run the widerface evaluation of `model`.

        Unlike the other metrics in this module, this method takes the
        model itself instead of a batch of inputs and outputs.
        """
        # fixed evaluation settings, only multi_scale is configurable
        eval_options = {
            'pred_dir': 'output/pred',
            'eval_mode': 'widerface',
            'multi_scale': self.multi_scale,
        }
        face_eval_run(model, self.image_dir, self.anno_file, **eval_options)