metrics.py 13.5 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
C
cuicheng01 已提交
18 19 20 21 22 23
import paddle.nn.functional as F

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize
W
weishengyu 已提交
24

C
cuicheng01 已提交
25 26
from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter
D
dongshuilong 已提交
27

C
cuicheng01 已提交
28 29

class TopkAcc(AvgMetrics):
    """Top-k classification accuracy with a running average per k.

    Args:
        topk (int | list | tuple): k values to evaluate; a bare int is
            wrapped into a one-element list. Defaults to (1, 5).
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.reset()

    def reset(self):
        """Create a fresh AverageMeter named "top{k}" for every k."""
        meters = {}
        for k in self.topk:
            name = "top{}".format(k)
            meters[name] = AverageMeter(name)
        self.avg_meters = meters

    def forward(self, x, label):
        """Compute batch top-k accuracy and update the running meters.

        Args:
            x: model output, or a dict whose "logits" entry holds it.
            label: ground-truth labels, passed to ``paddle.metric.accuracy``.

        Returns:
            dict: maps "top{k}" to the batch accuracy tensor.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        batch_size = logits.shape[0]
        result = {}
        for k in self.topk:
            name = "top{}".format(k)
            acc = paddle.metric.accuracy(logits, label, k=k)
            result[name] = acc
            # meters are weighted by batch size
            self.avg_meters[name].update(acc, batch_size)
        return result

D
dongshuilong 已提交
52

W
weishengyu 已提交
53
class mAP(nn.Layer):
    """Mean average precision for query-vs-gallery retrieval.

    Only queries with at least one relevant gallery item contribute to the
    mean.
    """

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"mAP": value} for one retrieval evaluation.

        Args:
            similarities_matrix: query-by-gallery scores; larger means more
                similar (ranked with ``descending=True``).
            query_img_id: query labels compared with the ranked gallery
                labels.
            gallery_img_id: gallery labels; transposed with [1, 0], so
                presumably shaped (num_gallery, 1) — TODO confirm.
            keep_mask: optional query-by-gallery mask; zero entries are
                never counted as matches. May be None.
        """
        # rank gallery items per query, most similar first
        order = paddle.argsort(similarities_matrix, axis=1, descending=True)
        label_row = paddle.transpose(gallery_img_id, [1, 0])
        label_grid = paddle.broadcast_to(
            label_row, shape=[order.shape[0], label_row.shape[1]])
        ranked_labels = paddle.index_sample(label_grid, order)

        hits = paddle.equal(ranked_labels, query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # drop queries without any relevant gallery item
        has_hit = paddle.greater_than(
            paddle.sum(hits, axis=1), paddle.to_tensor(0.))
        rows = paddle.nonzero(has_hit.astype("int"))
        rows = paddle.reshape(rows, [rows.shape[0]])
        hits = paddle.index_select(hits, rows, axis=0)

        # precision at every 1-based rank, kept only where a hit occurs
        hits_so_far = paddle.cumsum(hits, axis=1)
        ranks = paddle.arange(hits_so_far.shape[1]).astype("float32") + 1
        precision_at_hits = paddle.multiply(
            hits, paddle.divide(hits_so_far, ranks))
        ap = paddle.sum(precision_at_hits, axis=1) / paddle.sum(hits, axis=1)
        return {"mAP": paddle.mean(ap).numpy()[0]}

D
dongshuilong 已提交
96

W
weishengyu 已提交
97
class mINP(nn.Layer):
    """Mean inverse negative penalty (mINP) for query-vs-gallery retrieval.

    For each query with at least one relevant gallery item:
        INP = (number of relevant items) / (1-based rank of the
              lowest-ranked relevant item)
    mINP is the mean INP over those queries.
    """

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"mINP": value} for one retrieval evaluation.

        Args:
            similarities_matrix: query-by-gallery scores; larger means more
                similar (ranked with ``descending=True``).
            query_img_id: query labels compared with the ranked gallery
                labels.
            gallery_img_id: gallery labels; transposed with [1, 0], so
                presumably shaped (num_gallery, 1) — TODO confirm.
            keep_mask: optional query-by-gallery mask; zero entries are
                never counted as matches. May be None.
        """
        metric_dict = dict()

        # rank gallery items per query, most similar first
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # keep only queries with at least one relevant gallery item
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # Locate the hardest (lowest-ranked) positive of each query: a hit at
        # 0-based position i is weighted 1 - 1 / (i + 2) = (i + 1) / (i + 2),
        # which grows with i, so argmax lands on the last hit.
        div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
        minus = paddle.divide(equal_flag, div)
        auxilary = paddle.subtract(equal_flag, minus)
        # argmax yields a 0-based position; INP needs the 1-based rank
        # (otherwise a lone hit at the top rank divides by zero).
        hard_rank = paddle.argmax(auxilary, axis=1).astype("float32") + 1
        all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_rank)
        mINP = paddle.mean(all_INP)
        metric_dict["mINP"] = mINP.numpy()[0]
        return metric_dict

D
dongshuilong 已提交
139

C
cuicheng01 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
class TprAtFpr(nn.Layer):
    """Best true-positive rate reachable within a false-positive-rate cap.

    ``forward`` only collects the softmax probability of class index 1 for
    every sample. The threshold search runs when ``avg_info`` is read.

    Args:
        max_fpr (float): largest acceptable false-positive rate.
            Defaults to 1/1000.
    """

    def __init__(self, max_fpr=1/1000.):
        super().__init__()
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.softmax = nn.Softmax(axis=-1)
        self.max_fpr = max_fpr
        self.max_tpr = 0.

    def forward(self, x, label):
        """Record class-1 scores, split by ground truth; returns {}.

        Args:
            x: logits, or a dict whose "logits" entry holds them. Column 1
                is read after softmax.
            label: per-sample labels; ``label_i[0] == 0`` marks a negative
                sample, anything else a positive one.
        """
        if isinstance(x, dict):
            x = x["logits"]
        x = self.softmax(x)
        for i, label_i in enumerate(label):
            if label_i[0] == 0:
                self.gt_neg_score_list.append(x[i][1].numpy())
            else:
                self.gt_pos_score_list.append(x[i][1].numpy())
        return {}

    def reset(self):
        """Drop all collected scores and the cached best TPR."""
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.max_tpr = 0.

    @property
    def avg(self):
        """Best TPR found by the most recent read of ``avg_info``."""
        # only refreshed when avg_info is evaluated
        return self.max_tpr

    @property
    def avg_info(self):
        """Scan thresholds 0, 1e-4, ..., 0.9999 and describe the best one.

        A sample is predicted positive when its score is strictly greater
        than the threshold. Among thresholds whose FPR is at most
        ``max_fpr``, the first one reaching the highest TPR is reported.
        The TPR is also stored in ``self.max_tpr``.

        Returns:
            str: "threshold: ..., fpr: ..., tpr: ..." or "" when no
            threshold gives a positive TPR within the FPR cap.
        """
        max_tpr = 0.
        result = ""
        gt_pos_score_list = np.array(self.gt_pos_score_list)
        gt_neg_score_list = np.array(self.gt_neg_score_list)
        num_pos = len(gt_pos_score_list)
        num_neg = len(gt_neg_score_list)
        if num_pos == 0:
            # TPR is undefined without positive samples; nothing to report.
            self.max_tpr = max_tpr
            return result
        for i in range(0, 10000):
            threshold = i / 10000.
            tpr = np.sum(gt_pos_score_list > threshold) / num_pos
            if num_neg == 0:
                # without negatives there can be no false positives
                fpr = 0.
            else:
                fpr = np.sum(gt_neg_score_list > threshold) / num_neg
            if fpr <= self.max_fpr and tpr > max_tpr:
                max_tpr = tpr
                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
                    threshold, fpr, tpr)
        self.max_tpr = max_tpr
        return result


W
weishengyu 已提交
191
class Recallk(nn.Layer):
    """Recall@k (CMC) for query-vs-gallery retrieval.

    recall@k is the number of queries with a relevant gallery item among
    their first k ranked results, divided by the number of queries that
    have at least one relevant item.

    Args:
        topk (int | list | tuple): k values to report; a bare int is
            wrapped into a one-element list. Defaults to (1, 5).
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"recall{k}": value} for every k in ``self.topk``.

        Args:
            similarities_matrix: query-by-gallery scores; larger means more
                similar (ranked with ``descending=True``).
            query_img_id: query labels compared with the ranked gallery
                labels.
            gallery_img_id: gallery labels; transposed with [1, 0], so
                presumably shaped (num_gallery, 1) — TODO confirm.
            keep_mask: optional query-by-gallery mask; zero entries are
                never counted as matches. May be None.
        """
        # rank gallery items per query, most similar first
        order = paddle.argsort(similarities_matrix, axis=1, descending=True)
        label_row = paddle.transpose(gallery_img_id, [1, 0])
        label_grid = paddle.broadcast_to(
            label_row, shape=[order.shape[0], label_row.shape[1]])
        hits = paddle.equal(
            paddle.index_sample(label_grid, order), query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        zero = paddle.to_tensor(0.)
        # denominator: queries with at least one relevant gallery item
        num_valid = paddle.sum(
            paddle.greater_than(paddle.sum(hits, axis=1),
                                zero).astype("float32"))
        # found[q, i] is 1 once query q has a hit at rank <= i + 1
        found = paddle.greater_than(paddle.cumsum(hits, axis=1),
                                    zero).astype("float32")
        cmc = (paddle.sum(found, axis=0) / num_valid).numpy()
        return {"recall{}".format(k): cmc[k - 1] for k in self.topk}

D
dongshuilong 已提交
235

B
Bin Lu 已提交
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
class Precisionk(nn.Layer):
    """Precision@k for query-vs-gallery retrieval, averaged over all queries.

    Args:
        topk (int | list | tuple): k values to report; a bare int is
            wrapped into a one-element list. Defaults to (1, 5).
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"precision@{k}": value} for every k in ``self.topk``.

        Args:
            similarities_matrix: query-by-gallery scores; larger means more
                similar (ranked with ``descending=True``).
            query_img_id: query labels compared with the ranked gallery
                labels.
            gallery_img_id: gallery labels; transposed with [1, 0] and its
                first dimension used as the gallery size, so presumably
                shaped (num_gallery, 1) — TODO confirm.
            keep_mask: optional query-by-gallery mask; zero entries are
                never counted as matches. May be None.
        """
        # rank gallery items per query, most similar first
        order = paddle.argsort(similarities_matrix, axis=1, descending=True)
        label_row = paddle.transpose(gallery_img_id, [1, 0])
        label_grid = paddle.broadcast_to(
            label_row, shape=[order.shape[0], label_row.shape[1]])
        hits = paddle.equal(
            paddle.index_sample(label_grid, order), query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # mean number of hits within the first n results, divided by n
        cum_hits = paddle.cumsum(hits, axis=1)
        # NOTE(review): ranks stays an integer tensor while the mean is
        # float32; this relies on Paddle's mixed-dtype division — confirm.
        ranks = paddle.arange(gallery_img_id.shape[0]) + 1
        precision = (paddle.mean(cum_hits, axis=0) / ranks).numpy()
        return {
            "precision@{}".format(k): precision[k - 1]
            for k in self.topk
        }


277 278 279 280 281 282 283
class DistillationTopkAcc(TopkAcc):
    """TopkAcc applied to one entry of a dict of model outputs.

    Args:
        model_key: key selecting the output to score when ``x`` is a dict.
        feature_key: optional second key applied after ``model_key``
            (or directly to a non-dict ``x``).
        topk: forwarded to ``TopkAcc``.
    """

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        super().__init__(topk=topk)
        self.model_key = model_key
        self.feature_key = feature_key

    def forward(self, x, label):
        selected = x[self.model_key] if isinstance(x, dict) else x
        if self.feature_key is not None:
            selected = selected[self.feature_key]
        return super().forward(selected, label)
C
cuicheng01 已提交
289 290 291 292 293 294 295 296 297 298 299 300


class GoogLeNetTopkAcc(TopkAcc):
    """TopkAcc that scores only ``x[0]`` of an indexable model output.

    Presumably ``x`` is the tuple/list of heads produced by GoogLeNet, with
    the main head first — TODO confirm against the model.

    Args:
        topk (int | list | tuple): k values to evaluate. Defaults to (1, 5).
    """

    def __init__(self, topk=(1, 5)):
        # Forward topk so TopkAcc validates it, wraps a bare int and builds
        # the "top{k}" meters for exactly these k values; otherwise forward()
        # looks up meters that were never created.
        super().__init__(topk=topk)

    def forward(self, x, label):
        return super().forward(x[0], label)
C
cuicheng01 已提交
301 302


C
cuicheng01 已提交
303 304 305 306
class MultiLabelMetric(AvgMetrics):
    """Base for multi-label metrics that score hard 0/1 predictions.

    Args:
        bi_threshold (float): sigmoid probabilities strictly above this
            value become 1, all others 0. Defaults to 0.5.
    """

    def __init__(self, bi_threshold=0.5):
        super().__init__()
        self.bi_threshold = bi_threshold

    def _multi_hot_encode(self, output):
        """Map raw logits to a 0/1 numpy array via sigmoid + threshold."""
        return binarize(F.sigmoid(output).numpy(),
                        threshold=self.bi_threshold)
C
cuicheng01 已提交
311 312


C
cuicheng01 已提交
313
class HammingDistance(MultiLabelMetric):
    """Hamming loss for multi-label classification.

    The fraction of label slots that are predicted wrongly after the
    sigmoid outputs are binarized. Lower is better.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Start a fresh running average."""
        self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}

    def forward(self, output, target):
        """Score one batch and fold it into the running average.

        Args:
            output: raw logits; binarized with sigmoid and ``bi_threshold``.
            target: ground truth accepted by sklearn's ``hamming_loss``
                (presumably a multi-hot indicator matrix — TODO confirm).

        Returns:
            dict: {"HammingDistance": tensor holding the batch loss}.
        """
        preds = self._multi_hot_encode(output)
        metric_dict = dict()
        metric_dict["HammingDistance"] = paddle.to_tensor(
            hamming_loss(target, preds))
        # float() accepts both 1-element and 0-D tensors; .numpy()[0]
        # fails on 0-D tensors returned by to_tensor(scalar) in newer Paddle.
        self.avg_meters["HammingDistance"].update(
            float(metric_dict["HammingDistance"]), output.shape[0])
        return metric_dict


C
cuicheng01 已提交
336
class AccuracyScore(MultiLabelMetric):
    """Accuracy for multi-label classification on binarized predictions.

    Args:
        base (str): one of ["sample", "label"], default "label".
            "sample": sklearn ``accuracy_score``, i.e. the fraction of
                samples whose whole label vector is predicted exactly.
            "label": (TP + TN) / (TP + TN + FP + FN) pooled over every
                label slot of every sample.

    Returns (from ``forward``):
        dict: {"AccuracyScore": tensor holding the batch accuracy}.
    """

    def __init__(self, base="label"):
        super().__init__()
        assert base in ["sample", "label"
                        ], 'must be one of ["sample", "label"]'
        self.base = base
        self.reset()

    def reset(self):
        """Start a fresh running average."""
        self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}

    def forward(self, output, target):
        """Score one batch and fold it into the running average.

        Args:
            output: raw logits; binarized with sigmoid and ``bi_threshold``.
            target: ground truth accepted by sklearn's multi-label metrics
                (presumably a multi-hot indicator matrix — TODO confirm).
        """
        preds = self._multi_hot_encode(output)
        metric_dict = dict()
        if self.base == "sample":
            accuracy = accuracy_metric(target, preds)
        elif self.base == "label":
            # one 2x2 matrix per label: [[tn, fp], [fn, tp]]
            mcm = multilabel_confusion_matrix(target, preds)
            tns = mcm[:, 0, 0].sum()
            fns = mcm[:, 1, 0].sum()
            tps = mcm[:, 1, 1].sum()
            fps = mcm[:, 0, 1].sum()
            accuracy = (tps + tns) / (tps + tns + fns + fps)
        metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
        # float() accepts both 1-element and 0-D tensors; .numpy()[0]
        # fails on 0-D tensors returned by to_tensor(scalar) in newer Paddle.
        self.avg_meters["AccuracyScore"].update(
            float(metric_dict["AccuracyScore"]), output.shape[0])
        return metric_dict