metrics.py 16.3 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
C
cuicheng01 已提交
18 19 20 21 22 23
import paddle.nn.functional as F

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize
W
weishengyu 已提交
24

Z
zhiboniu 已提交
25 26
from easydict import EasyDict

C
cuicheng01 已提交
27
from ppcls.metric.avg_metrics import AvgMetrics
Z
zhiboniu 已提交
28
from ppcls.utils.misc import AverageMeter, AttrMeter
29
from ppcls.utils import logger
D
dongshuilong 已提交
30

C
cuicheng01 已提交
31 32

class TopkAcc(AvgMetrics):
    """Top-k classification accuracy with one running AverageMeter per k.

    Args:
        topk (int | list | tuple): k value(s) to evaluate; a bare int is
            wrapped into a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.reset()

    def reset(self):
        """Create a fresh, empty meter for every configured k."""
        self.avg_meters = {}
        for k in self.topk:
            name = f"top{k}"
            self.avg_meters[name] = AverageMeter(name)

    def forward(self, x, label):
        """Compute batch top-k accuracy and update the running meters.

        Args:
            x: model output; when a dict, its "logits" entry is used. The
                size of the last dimension bounds the usable k values.
            label: ground-truth labels passed to ``paddle.metric.accuracy``.

        Returns:
            dict mapping "top{k}" to the batch accuracy for every k that does
            not exceed the last dimension of the output. Larger k values are
            logged, removed from the meters and dropped from ``self.topk``.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        num_outputs = logits.shape[-1]

        metric_dict = dict()
        for k in self.topk:
            name = f"top{k}"
            if k > num_outputs:
                # k larger than the output size is meaningless: discard it.
                logger.warning(
                    f"The output dims({num_outputs}) is less than k({k}), and the argument {k} of Topk has been removed."
                )
                self.avg_meters.pop(name)
                continue
            batch_acc = paddle.metric.accuracy(logits, label, k=k)
            metric_dict[name] = batch_acc
            self.avg_meters[name].update(batch_acc, logits.shape[0])

        # Keep only the k values that fit, so later calls don't warn again.
        self.topk = [k for k in self.topk if k <= num_outputs]
        return metric_dict

D
dongshuilong 已提交
68

W
weishengyu 已提交
69
class mAP(nn.Layer):
    """Mean average precision (mAP) for retrieval evaluation.

    Args:
        descending (bool): sort order applied to ``similarities_matrix``
            along axis 1 when ranking gallery items (True ranks the largest
            values first).
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mAP over the queries that have at least one relevant item.

        Args:
            similarities_matrix: query-by-gallery score matrix, ranked along
                axis 1.
            query_img_id: query ids compared against ranked gallery ids;
                presumably shaped (num_query, 1) — TODO confirm.
            gallery_img_id: gallery ids; the transpose + broadcast below
                imply a (num_gallery, 1) layout — TODO confirm.
            keep_mask: optional query-by-gallery mask; positions where it is
                falsy are never counted as relevant.

        Returns:
            dict with key "mAP" holding a numpy scalar.
        """
        metric_dict = dict()

        # Rank the gallery for each query and fetch the ids in ranked order.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # Reorder the mask to ranked order and drop masked-out matches.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Keep only queries with at least one relevant item; AP would divide
        # by zero for the others.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # precision@r = (#relevant in top r) / r, with r 1-based.
        acc_sum = paddle.cumsum(equal_flag, axis=1)
        div = paddle.arange(acc_sum.shape[1]).astype("float32") + 1
        precision = paddle.divide(acc_sum, div)

        # AP = mean of precision@r over the ranks r that hold a relevant item.
        precision_mask = paddle.multiply(equal_flag, precision)
        ap = paddle.sum(precision_mask, axis=1) / paddle.sum(equal_flag,
                                                             axis=1)
        # reshape(-1) handles both a 1-element 1-D result (older Paddle) and a
        # 0-D result (Paddle >= 2.5), where `.numpy()[0]` raises IndexError.
        metric_dict["mAP"] = paddle.mean(ap).numpy().reshape(-1)[0]
        return metric_dict

D
dongshuilong 已提交
113

W
weishengyu 已提交
114
class mINP(nn.Layer):
    """Mean inverse negative penalty (mINP) for retrieval evaluation.

    For each query with at least one relevant gallery item,
    INP = (number of relevant items) / (1-based rank of the last relevant
    item); mINP is the mean over those queries.

    Args:
        descending (bool): sort order applied to ``similarities_matrix``
            along axis 1 when ranking gallery items (True ranks the largest
            values first).
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mINP over the queries that have at least one relevant item.

        Args:
            similarities_matrix: query-by-gallery score matrix, ranked along
                axis 1.
            query_img_id: query ids; presumably (num_query, 1) — TODO confirm.
            gallery_img_id: gallery ids; the transpose + broadcast below
                imply a (num_gallery, 1) layout — TODO confirm.
            keep_mask: optional query-by-gallery mask; positions where it is
                falsy are never counted as relevant.

        Returns:
            dict with key "mINP" holding a numpy scalar.
        """
        metric_dict = dict()

        # Rank the gallery for each query and fetch the ids in ranked order.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # Reorder the mask to ranked order (was the non-existent
            # `paddle.indechmx_sample`, which raised AttributeError).
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Keep only queries with at least one relevant item.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # Relevant positions get weight (r + 1) / (r + 2) for 0-based rank r.
        # The weight strictly increases with r, so argmax returns the LAST
        # relevant (hardest) position.
        div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
        minus = paddle.divide(equal_flag, div)
        auxilary = paddle.subtract(equal_flag, minus)
        hard_index = paddle.argmax(auxilary, axis=1).astype("float32")
        # hard_index is 0-based while INP uses the 1-based rank. Dividing by
        # hard_index itself overestimated INP and gave inf when the only
        # match was ranked first.
        all_INP = paddle.divide(
            paddle.sum(equal_flag, axis=1), hard_index + 1)
        mINP = paddle.mean(all_INP)
        # reshape(-1) works for both 1-element 1-D and 0-D (Paddle >= 2.5)
        # results; `.numpy()[0]` raises IndexError on the latter.
        metric_dict["mINP"] = mINP.numpy().reshape(-1)[0]
        return metric_dict

D
dongshuilong 已提交
157

C
cuicheng01 已提交
158
class TprAtFpr(nn.Layer):
    """Best true-positive rate achievable with FPR <= ``max_fpr``.

    ``forward`` accumulates the softmax score of class index 1 for every
    sample, split by ground truth. ``avg_info`` sweeps the thresholds
    0, 1e-4, ..., 0.9999. It keeps the highest TPR whose FPR does not exceed
    ``max_fpr`` and stores that TPR in ``self.max_tpr``.

    Args:
        max_fpr (float): upper bound on the false-positive rate.
    """

    def __init__(self, max_fpr=1 / 1000.):
        super().__init__()
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.softmax = nn.Softmax(axis=-1)
        self.max_fpr = max_fpr
        self.max_tpr = 0.

    def forward(self, x, label):
        """Accumulate class-1 scores for the batch.

        A sample with ``label_i[0] == 0`` is stored as a negative, anything
        else as a positive. This assumes a two-class output — TODO confirm.

        Returns:
            An empty dict; results are read via ``avg`` / ``avg_info``.
        """
        if isinstance(x, dict):
            x = x["logits"]
        x = self.softmax(x)
        for i, label_i in enumerate(label):
            if label_i[0] == 0:
                self.gt_neg_score_list.append(x[i][1].numpy())
            else:
                self.gt_pos_score_list.append(x[i][1].numpy())
        return {}

    def reset(self):
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.max_tpr = 0.

    @property
    def avg(self):
        """TPR stored by the most recent ``avg_info`` evaluation."""
        return self.max_tpr

    @property
    def avg_info(self):
        """Sweep thresholds and describe the best one.

        Side effect: stores the best TPR in ``self.max_tpr``.

        Returns:
            str like "threshold: ..., fpr: ..., tpr: ..." for the best
            threshold, or "" when there are no positives or no threshold
            qualifies.
        """
        max_tpr = 0.
        result = ""
        gt_pos_score_list = np.array(self.gt_pos_score_list)
        gt_neg_score_list = np.array(self.gt_neg_score_list)
        num_pos = len(gt_pos_score_list)
        num_neg = len(gt_neg_score_list)
        # TPR is undefined without positives; leave max_tpr at 0.
        if num_pos > 0:
            for i in range(0, 10000):
                threshold = i / 10000.
                tpr = np.sum(gt_pos_score_list > threshold) / num_pos
                if num_neg == 0:
                    # No negatives means no false positives. Previously this
                    # path read `fpr` before assignment and then divided by 0.
                    fpr = 0.
                else:
                    fpr = np.sum(gt_neg_score_list > threshold) / num_neg
                if (num_neg == 0 or fpr <= self.max_fpr) and tpr > max_tpr:
                    max_tpr = tpr
                    result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
                        threshold, fpr, tpr)
        self.max_tpr = max_tpr
        return result


W
weishengyu 已提交
213
class Recallk(nn.Layer):
    """Recall@k (CMC curve) for retrieval evaluation.

    Args:
        topk (int | list | tuple): rank cut-off(s) to report.
        descending (bool): sort order applied to ``similarities_matrix``
            along axis 1 when ranking gallery items.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"recall{k}": fraction of matchable queries hit within k}.

        The denominator counts only queries with at least one relevant
        gallery item (after applying ``keep_mask`` when given).
        """
        # Rank gallery items per query and gather their ids in ranked order.
        ranked_idx = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_grid = paddle.broadcast_to(
            gallery_row, shape=[ranked_idx.shape[0], gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_grid, ranked_idx)

        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), ranked_idx)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # Number of queries that have any valid match at all.
        hits_per_query = paddle.sum(hits, axis=1)
        has_match = paddle.greater_than(hits_per_query, paddle.to_tensor(0.))
        matchable_queries = paddle.sum(has_match.astype("float32"))

        # found[q, r] = 1 once query q has seen a match at rank <= r.
        found = paddle.greater_than(
            paddle.cumsum(hits, axis=1), paddle.to_tensor(0.)).astype("float32")
        cmc = (paddle.sum(found, axis=0) / matchable_queries).numpy()

        return {"recall{}".format(k): cmc[k - 1] for k in self.topk}

D
dongshuilong 已提交
258

B
Bin Lu 已提交
259
class Precisionk(nn.Layer):
    """Precision@k for retrieval evaluation, averaged over all queries.

    Unlike mAP/Recallk in this module, queries without any relevant gallery
    item are NOT filtered out; they contribute 0 to every precision@k.

    Args:
        topk (int | list | tuple): rank cut-off(s) to report.
        descending (bool): sort order applied to ``similarities_matrix``
            along axis 1 when ranking gallery items.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return {"precision@{k}": mean over queries of hits-in-top-k / k}."""
        metric_dict = dict()

        # Rank the gallery for each query and fetch the ids in ranked order.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Ns[r] = r + 1 items retrieved at 0-based rank r. Cast explicitly so
        # the division is float32 / float32 instead of relying on Paddle's
        # implicit int64 -> float32 conversion of the right operand.
        Ns = (paddle.arange(gallery_img_id.shape[0]) + 1).astype("float32")
        equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1)
        Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy()

        for k in self.topk:
            metric_dict["precision@{}".format(k)] = Precision_at_k[k - 1]

        return metric_dict


301 302 303 304 305 306 307
class DistillationTopkAcc(TopkAcc):
    """TopkAcc evaluated on one entry of a dict-shaped model output.

    Args:
        model_key: key used to pick the sub-output when ``x`` is a dict.
        feature_key: optional second key applied to the picked sub-output.
        topk: forwarded unchanged to TopkAcc.
    """

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        super().__init__(topk=topk)
        self.model_key = model_key
        self.feature_key = feature_key

    def forward(self, x, label):
        picked = x[self.model_key] if isinstance(x, dict) else x
        if self.feature_key is not None:
            picked = picked[self.feature_key]
        return super().forward(picked, label)
C
cuicheng01 已提交
313 314 315 316 317 318 319 320 321 322 323 324


class GoogLeNetTopkAcc(TopkAcc):
    """TopkAcc for GoogLeNet outputs; only ``x[0]`` is scored.

    Args:
        topk (int | list | tuple): k value(s) to evaluate.
    """

    def __init__(self, topk=(1, 5)):
        # Forward topk to the base class so its validation runs and reset()
        # builds meters for exactly these k values. Previously the base class
        # was built with its default (1, 5) and self.topk overwritten
        # afterwards. That left avg_meters out of sync, e.g. a KeyError for
        # "top10".
        super().__init__(topk=topk)

    def forward(self, x, label):
        return super().forward(x[0], label)
C
cuicheng01 已提交
325 326


C
cuicheng01 已提交
327 328 329 330
class MultiLabelMetric(AvgMetrics):
    """Shared base for multi-label metrics operating on thresholded sigmoids.

    Args:
        bi_threshold (float): cut-off passed to
            ``sklearn.preprocessing.binarize``.
    """

    def __init__(self, bi_threshold=0.5):
        super().__init__()
        self.bi_threshold = bi_threshold

    def _multi_hot_encode(self, output):
        """Map raw outputs to a 0/1 numpy array: sigmoid, then binarize."""
        probabilities = F.sigmoid(output).numpy()
        return binarize(probabilities, threshold=self.bi_threshold)
C
cuicheng01 已提交
335 336


C
cuicheng01 已提交
337
class HammingDistance(MultiLabelMetric):
    """Hamming loss for multi-label classification (lower is better).

    Uses ``sklearn.metrics.hamming_loss`` on sigmoid outputs binarized at
    the base class's default ``bi_threshold``.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}

    def forward(self, output, target):
        """Return {"HammingDistance": batch loss tensor}; update the meter."""
        preds = super()._multi_hot_encode(output)
        loss = hamming_loss(target, preds)
        metric_dict = dict()
        metric_dict["HammingDistance"] = paddle.to_tensor(loss)
        # Feed the host-side scalar to the meter directly. `.numpy()[0]` on
        # the tensor raises IndexError when paddle.to_tensor returns a 0-D
        # tensor for scalar input (Paddle >= 2.5).
        self.avg_meters["HammingDistance"].update(loss, output.shape[0])
        return metric_dict


C
cuicheng01 已提交
361
class AccuracyScore(MultiLabelMetric):
    """
    Hard metric for multilabel classification
    Args:
        base: ["sample", "label"], default="label"
            if "sample", return subset accuracy
                (``sklearn.metrics.accuracy_score``): a sample counts as
                correct only if all of its labels match.
            if "label", return (TP + TN) / (TP + TN + FP + FN) summed over
                all labels, i.e. the fraction of individual label decisions
                that are correct (equivalently 1 - Hamming loss).
    Returns:
        accuracy: dict {"AccuracyScore": batch accuracy tensor} from forward.
    """

    def __init__(self, base="label"):
        super().__init__()
        assert base in ["sample", "label"
                        ], 'must be one of ["sample", "label"]'
        self.base = base
        self.reset()

    def reset(self):
        self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}

    def forward(self, output, target):
        preds = super()._multi_hot_encode(output)
        metric_dict = dict()
        if self.base == "sample":
            accuracy = accuracy_metric(target, preds)
        elif self.base == "label":
            # One 2x2 confusion matrix per label: [[TN, FP], [FN, TP]].
            mcm = multilabel_confusion_matrix(target, preds)
            tns = mcm[:, 0, 0].sum()
            fns = mcm[:, 1, 0].sum()
            tps = mcm[:, 1, 1].sum()
            fps = mcm[:, 0, 1].sum()
            accuracy = (tps + tns) / (tps + tns + fns + fps)
        metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
        # Feed the host-side scalar directly; `.numpy()[0]` fails on the 0-D
        # tensor paddle.to_tensor returns for scalars in Paddle >= 2.5.
        self.avg_meters["AccuracyScore"].update(accuracy, output.shape[0])
        return metric_dict
Z
zhiboniu 已提交
399 400 401 402 403


def get_attr_metrics(gt_label, preds_probs, threshold):
    """Collect per-attribute and per-instance counts for attribute metrics.

    Adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"

    Args:
        gt_label (np.ndarray): ground truth with 1 = positive, 0 = negative
            and -1 = ignored. Presumably (num_samples, num_attrs): axis 0 is
            reduced for label metrics and axis 1 for instance metrics —
            TODO confirm.
        preds_probs (np.ndarray): predicted probabilities, same shape as
            ``gt_label``.
        threshold (float): probabilities strictly greater than this are
            predicted positive.

    Returns:
        EasyDict of float arrays:
            label metrics (reduced over axis 0): gt_pos, gt_neg, true_pos,
                true_neg, false_pos, false_neg.
            instance metrics (reduced over axis 1): gt_pos_ins,
                true_pos_ins, intersect_pos, union_pos.
    """
    pred_label = (preds_probs > threshold).astype(int)

    # Positions whose ground truth is -1 are set to -1 in the prediction too,
    # so they are never counted as predicted positives/negatives below.
    ignore_mask = gt_label == -1
    pred_label[ignore_mask] = -1

    result = EasyDict()

    ###############################
    # label metrics
    # TP + FN
    result.gt_pos = np.sum((gt_label == 1), axis=0).astype(float)
    # TN + FP
    result.gt_neg = np.sum((gt_label == 0), axis=0).astype(float)
    # TP
    result.true_pos = np.sum((gt_label == 1) * (pred_label == 1),
                             axis=0).astype(float)
    # TN
    result.true_neg = np.sum((gt_label == 0) * (pred_label == 0),
                             axis=0).astype(float)
    # FP
    result.false_pos = np.sum(((gt_label == 0) * (pred_label == 1)),
                              axis=0).astype(float)
    # FN
    result.false_neg = np.sum(((gt_label == 1) * (pred_label == 0)),
                              axis=0).astype(float)

    ################
    # instance metrics
    # ground-truth positives per instance
    result.gt_pos_ins = np.sum((gt_label == 1), axis=1).astype(float)
    # predicted positives per instance (despite the key name)
    result.true_pos_ins = np.sum((pred_label == 1), axis=1).astype(float)
    # true positives per instance (intersection of gt and prediction)
    result.intersect_pos = np.sum((gt_label == 1) * (pred_label == 1),
                                  axis=1).astype(float)
    # union of gt and predicted positives (bool + bool is logical OR)
    result.union_pos = np.sum(((gt_label == 1) + (pred_label == 1)),
                              axis=1).astype(float)

    return result


class ATTRMetric(nn.Layer):
    """Attribute-recognition metric accumulated through an AttrMeter.

    Args:
        threshold (float): passed to ``get_attr_metrics`` and AttrMeter;
            probabilities above it are predicted positive.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        # Create the meter immediately so forward() works even if reset()
        # is never called explicitly beforehand.
        self.reset()

    def reset(self):
        # Use the configured threshold (previously hard-coded to 0.5).
        self.attrmeter = AttrMeter(threshold=self.threshold)

    def forward(self, output, target):
        """Compute batch counts, add them to the meter and return them.

        Only ``target[:, 0, :]`` is used as the ground-truth label matrix.
        What the other slices along axis 1 hold is not visible here — TODO
        confirm.
        """
        metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
                                       output.numpy(), self.threshold)
        self.attrmeter.update(metric_dict)
        return metric_dict