# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import paddle
import numpy as np

from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from ppdet.data.source.category import get_categories

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = [
    'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results'
]


class Metric(paddle.metric.Metric):
    """Common base class for ppdet metrics.

    ``paddle.metric.Metric`` defines :meth:`update`, :meth:`accumulate` and
    :meth:`reset`; ppdet additionally expects :meth:`log` (log the metric
    results) and :meth:`get_results` (return the metric results). Apart
    from :meth:`name`, every hook below is a no-op returning ``None`` that
    subclasses override as needed.
    """

    def name(self):
        """Return the name of the concrete metric class."""
        return type(self).__name__

    def reset(self):
        """Reset accumulated state (no-op in the base class)."""
        return None

    def accumulate(self):
        """Aggregate collected results (no-op in the base class)."""
        return None

    # abstract hook: log metric results
    def log(self):
        """Log metric results (no-op in the base class)."""
        return None

    # abstract hook: return metric results
    def get_results(self):
        """Return metric results (``None`` in the base class)."""
        return None


class COCOMetric(Metric):
    """COCO-style evaluation of bbox / mask / segm predictions.

    :meth:`update` converts model outputs to COCO result entries via
    ``get_infer_results`` and collects them per result type.
    :meth:`accumulate` dumps each non-empty result type to
    ``<type>.json`` (inside ``output_eval`` when given) and, unless
    ``save_prediction_only`` is set, scores it with ``cocoapi_eval``.

    Args:
        anno_file (str): path to the COCO annotation file; must exist.
        **kwargs: optional settings read here:
            clsid2catid (dict|None): class-id -> category-id mapping; when
                None it is derived with ``get_categories('COCO', anno_file)``.
            classwise (bool): forwarded to ``cocoapi_eval``. Default False.
            output_eval (str|None): directory for the json dumps.
            bias: forwarded to ``get_infer_results``. Default 0.
            save_prediction_only (bool): only dump json, skip evaluation.
                Default False.
    """

    # (result type collected by update, iou_type for cocoapi_eval,
    #  key under which the stats are stored in eval_results).
    # NOTE: 'segm' results are stored under the 'mask' key, exactly as
    # before, so code reading eval_results['mask'] keeps working. If both
    # 'mask' and 'segm' results exist, the 'segm' stats win (same order as
    # the original implementation).
    _RESULT_SPECS = (
        ('bbox', 'bbox', 'bbox'),
        ('mask', 'segm', 'mask'),
        ('segm', 'segm', 'mask'),
    )

    def __init__(self, anno_file, **kwargs):
        assert os.path.isfile(anno_file), \
                "anno_file {} not a file".format(anno_file)
        self.anno_file = anno_file
        self.clsid2catid = kwargs.get('clsid2catid', None)
        if self.clsid2catid is None:
            self.clsid2catid, _ = get_categories('COCO', anno_file)
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        # TODO: bias should be unified
        self.bias = kwargs.get('bias', 0)
        self.save_prediction_only = kwargs.get('save_prediction_only', False)
        self.reset()

    def reset(self):
        """Drop collected predictions and previous evaluation stats."""
        # only bbox and mask evaluation support currently
        self.results = {'bbox': [], 'mask': [], 'segm': []}
        self.eval_results = {}

    def update(self, inputs, outputs):
        """Collect COCO result entries for one batch.

        Args:
            inputs (dict): batch inputs; only ``inputs['im_id']`` is read.
            outputs (dict): model outputs; paddle Tensors are converted to
                numpy before being passed to ``get_infer_results``.
        """
        # outputs Tensor -> numpy.ndarray
        outs = {
            k: v.numpy() if isinstance(v, paddle.Tensor) else v
            for k, v in outputs.items()
        }
        im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        infer_results = get_infer_results(
            outs, self.clsid2catid, bias=self.bias)
        for res_type in self.results:
            self.results[res_type] += infer_results.get(res_type, [])

    def accumulate(self):
        """Dump every non-empty result type and evaluate it if requested."""
        for res_type, iou_type, stats_key in self._RESULT_SPECS:
            if self.results[res_type]:
                self._dump_and_eval(res_type, iou_type, stats_key)

    def _dump_and_eval(self, res_type, iou_type, stats_key):
        """Write one result type to json, then optionally run cocoapi_eval."""
        output = "{}.json".format(res_type)
        if self.output_eval:
            output = os.path.join(self.output_eval, output)
        with open(output, 'w') as f:
            json.dump(self.results[res_type], f)
        # Report the real path: it differs from '<type>.json' whenever
        # output_eval is set.
        logger.info('The {} result is saved to {}.'.format(res_type, output))

        if self.save_prediction_only:
            logger.info('The {} result is saved to {} and do not '
                        'evaluate the mAP.'.format(res_type, output))
            return

        self.eval_results[stats_key] = cocoapi_eval(
            output,
            iou_type,
            anno_file=self.anno_file,
            classwise=self.classwise)
        sys.stdout.flush()

    def log(self):
        pass

    def get_results(self):
        """Return the stats dict filled by :meth:`accumulate`."""
        return self.eval_results


class VOCMetric(Metric):
    """Pascal VOC style mAP evaluation backed by ``DetectionMAP``.

    Args:
        label_list (str): path to the label list file; must exist.
        class_num (int): number of classes forwarded to DetectionMAP.
            Default 20.
        overlap_thresh (float): overlap threshold forwarded to
            DetectionMAP. Default 0.5.
        map_type (str): mAP mode forwarded to DetectionMAP.
            Default '11point'.
        is_bbox_normalized (bool): forwarded to DetectionMAP.
        evaluate_difficult (bool): when False, ``inputs['difficult']`` is
            read and passed to DetectionMAP; when True, ``None`` is passed.
        classwise (bool): forwarded to DetectionMAP.
    """

    def __init__(self,
                 label_list,
                 class_num=20,
                 overlap_thresh=0.5,
                 map_type='11point',
                 is_bbox_normalized=False,
                 evaluate_difficult=False,
                 classwise=False):
        assert os.path.isfile(label_list), \
                "label_list {} not a file".format(label_list)
        self.clsid2catid, self.catid2name = get_categories('VOC', label_list)

        self.overlap_thresh = overlap_thresh
        self.map_type = map_type
        self.evaluate_difficult = evaluate_difficult
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=overlap_thresh,
            map_type=map_type,
            is_bbox_normalized=is_bbox_normalized,
            evaluate_difficult=evaluate_difficult,
            catid2name=self.catid2name,
            classwise=classwise)

        self.reset()

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a numpy array, converting paddle Tensors."""
        return x.numpy() if isinstance(x, paddle.Tensor) else np.asarray(x)

    def reset(self):
        self.detection_map.reset()

    def update(self, inputs, outputs):
        """Feed one batch of detections and ground truth to DetectionMAP.

        Args:
            inputs (dict): batch inputs with 'gt_bbox', 'gt_class',
                optionally 'scale_factor', and 'difficult' (read only when
                ``evaluate_difficult`` is False). Entries may be paddle
                Tensors or numpy arrays.
            outputs (dict): model outputs with 'bbox' (column 0 = label,
                column 1 = score, columns 2: = box) and 'bbox_num' (number
                of detections per image), as Tensors or numpy arrays.
        """
        bbox_out = outputs['bbox']
        # Check for None before touching the value; the old check ran after
        # .shape/.numpy() had already been evaluated and so could never fire.
        if bbox_out is None:
            return
        # Convert once instead of once per sliced column.
        bbox_np = self._to_numpy(bbox_out)
        bboxes = bbox_np[:, 2:]
        scores = bbox_np[:, 1]
        labels = bbox_np[:, 0]
        bbox_lengths = self._to_numpy(outputs['bbox_num'])

        # A (1, 1) box array is skipped, as before; presumably the
        # "no detections" placeholder output — TODO confirm.
        if bboxes.shape == (1, 1):
            return

        gt_boxes = inputs['gt_bbox']
        gt_labels = inputs['gt_class']
        difficults = inputs['difficult'] if not self.evaluate_difficult \
                            else None

        if 'scale_factor' in inputs:
            scale_factor = self._to_numpy(inputs['scale_factor'])
        else:
            scale_factor = np.ones((len(gt_boxes), 2)).astype('float32')

        # Detections for all images are concatenated; bbox_lengths[i] says
        # how many consecutive rows belong to image i.
        bbox_idx = 0
        for i in range(len(gt_boxes)):
            gt_box = self._to_numpy(gt_boxes[i])
            h, w = scale_factor[i]
            # divide gt boxes by the per-image (h, w) scale factor
            gt_box = gt_box / np.array([w, h, w, h])
            gt_label = self._to_numpy(gt_labels[i])
            difficult = None if difficults is None \
                else self._to_numpy(difficults[i])

            bbox_num = bbox_lengths[i]
            end = bbox_idx + bbox_num
            bbox = bboxes[bbox_idx:end]
            score = scores[bbox_idx:end]
            label = labels[bbox_idx:end]

            gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
                                                             difficult)
            self.detection_map.update(bbox, score, label, gt_box, gt_label,
                                      difficult)
            bbox_idx = end

    def accumulate(self):
        logger.info("Accumulating evaluation results...")
        self.detection_map.accumulate()

    def log(self):
        """Log the mAP (as a percentage) with the threshold and map type."""
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        """Return ``{'bbox': [mAP]}`` with mAP from DetectionMAP."""
        return {'bbox': [self.detection_map.get_map()]}


class WiderFaceMetric(Metric):
    """WIDER FACE evaluation delegated to ``face_eval_run``.

    Unlike the other metrics in this module, :meth:`update` receives the
    model itself and runs the whole evaluation through ``face_eval_run``.

    Args:
        image_dir (str): image directory forwarded to face_eval_run.
        anno_file (str): annotation file forwarded to face_eval_run.
        multi_scale (bool): forwarded to face_eval_run. Default True.
        pred_dir (str): value passed as ``pred_dir`` to face_eval_run.
            Defaults to 'output/pred', the previously hard-coded value.
    """

    def __init__(self, image_dir, anno_file, multi_scale=True,
                 pred_dir='output/pred'):
        self.image_dir = image_dir
        self.anno_file = anno_file
        self.multi_scale = multi_scale
        self.pred_dir = pred_dir
        self.clsid2catid, self.catid2name = get_categories('widerface')

    def update(self, model):
        """Run the WIDER FACE evaluation for ``model``."""
        face_eval_run(
            model,
            self.image_dir,
            self.anno_file,
            pred_dir=self.pred_dir,
            eval_mode='widerface',
            multi_scale=self.multi_scale)