# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize

from easydict import EasyDict

class TopkAcc(nn.Layer):
    """
    Top-k classification accuracy for one or several values of k.
    Args:
        topk: an int, or a list/tuple of ints. A single int is wrapped into
            a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, x, label):
        """
        Args:
            x: model output, or a dict whose "logits" entry is the output.
            label: ground-truth labels handed to paddle.metric.accuracy.
        Returns:
            dict mapping "top<k>" to paddle.metric.accuracy(x, label, k=k)
            for every configured k.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        return {
            "top{}".format(k): paddle.metric.accuracy(
                logits, label, k=k)
            for k in self.topk
        }

class mAP(nn.Layer):
    """
    Mean average precision for retrieval.
    Args:
        descending: forwarded to paddle.argsort when ranking the gallery for
            each query. Presumably True when larger matrix values mean
            "more similar" and False for distances -- confirm with caller.
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """
        Args:
            similarities_matrix: 2-D tensor ranked row by row (axis=1).
                Assumed [num_query, num_gallery] -- TODO confirm.
            query_img_id: ids compared against the ranked gallery ids.
                Assumed [num_query, 1] -- TODO confirm.
            gallery_img_id: 2-D tensor of gallery ids (it is transposed with
                perm [1, 0]). Assumed [num_gallery, 1] -- TODO confirm.
            keep_mask: optional tensor laid out like similarities_matrix;
                entries that cast to False never count as a match. May be
                None.
        Returns:
            dict: {"mAP": float}, averaged over the queries that have at
            least one match.
        """
        metric_dict = dict()

        # Rank the gallery for every query, then gather the gallery ids in
        # that ranked order.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)

        # equal_flag[i, r] is 1.0 when the item at rank r matches query i.
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # Reorder the mask the same way as the labels. Masked entries
            # keep their rank position but are never counted as matches.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Drop queries without any match: their AP would be 0 / 0.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # precision[i, r] = (#matches within the first r + 1 ranks) / (r + 1)
        acc_sum = paddle.cumsum(equal_flag, axis=1)
        div = paddle.arange(acc_sum.shape[1]).astype("float32") + 1
        precision = paddle.divide(acc_sum, div)

        # AP = mean of the precision values taken at the matching ranks.
        precision_mask = paddle.multiply(equal_flag, precision)
        ap = paddle.sum(precision_mask, axis=1) / paddle.sum(equal_flag,
                                                             axis=1)
        # float() handles both a shape-[1] and a 0-D result of paddle.mean;
        # indexing `.numpy()[0]` fails on the 0-D case.
        metric_dict["mAP"] = float(paddle.mean(ap))
        return metric_dict

class mINP(nn.Layer):
    """
    Mean inverse negative penalty (mINP) for retrieval.

    For each query, INP = (number of matches) / (rank of the hardest match),
    where the hardest match is the matching gallery item ranked last and the
    rank is 1-based.
    Args:
        descending: forwarded to paddle.argsort when ranking the gallery for
            each query. Presumably True when larger matrix values mean
            "more similar" and False for distances -- confirm with caller.
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """
        Args:
            similarities_matrix: 2-D tensor ranked row by row (axis=1).
                Assumed [num_query, num_gallery] -- TODO confirm.
            query_img_id: ids compared against the ranked gallery ids.
                Assumed [num_query, 1] -- TODO confirm.
            gallery_img_id: 2-D tensor of gallery ids (it is transposed with
                perm [1, 0]). Assumed [num_gallery, 1] -- TODO confirm.
            keep_mask: optional tensor laid out like similarities_matrix;
                entries that cast to False never count as a match. May be
                None.
        Returns:
            dict: {"mINP": float}, averaged over the queries that have at
            least one match.
        """
        metric_dict = dict()

        # Rank the gallery for every query, then gather the gallery ids in
        # that ranked order.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)

        # equal_flag[i, r] is 1.0 when the item at rank r matches query i.
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # Reorder the mask the same way as the labels. Masked entries
            # keep their rank position but are never counted as matches.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Drop queries without any match: INP is undefined for them.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # Locate the last match of every row. A match at index r is scored
        # 1 - 1 / (r + 2), which grows strictly with r, while non-matches
        # score 0; argmax therefore returns the index of the last match.
        div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
        minus = paddle.divide(equal_flag, div)
        auxilary = paddle.subtract(equal_flag, minus)
        # argmax is a 0-based index; INP needs the 1-based rank. Without the
        # +1 a perfect ranking scores above 1 and a single match at index 0
        # divides by zero.
        hard_rank = paddle.argmax(auxilary, axis=1).astype("float32") + 1
        all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_rank)
        # float() handles both a shape-[1] and a 0-D result of paddle.mean;
        # indexing `.numpy()[0]` fails on the 0-D case.
        metric_dict["mINP"] = float(paddle.mean(all_INP))
        return metric_dict

class Recallk(nn.Layer):
    """
    Recall@k for retrieval: the share of queries (among those that have at
    least one match) whose first k ranked gallery items contain a match.
    Args:
        topk: an int, or a list/tuple of ints. A single int is wrapped into
            a one-element list.
        descending: forwarded to paddle.argsort when ranking the gallery.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """
        Args:
            similarities_matrix: 2-D tensor ranked row by row (axis=1).
            query_img_id: ids compared against the ranked gallery ids.
            gallery_img_id: 2-D tensor of gallery ids (transposed with
                perm [1, 0] before use).
            keep_mask: optional tensor laid out like similarities_matrix;
                entries that cast to False never count as a match.
        Returns:
            dict mapping "recall<k>" to the recall value for every k.
        """
        # Rank the gallery per query and fetch the ids in ranked order.
        order = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_rows = paddle.broadcast_to(
            gallery_row, shape=[order.shape[0], gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_rows, order)

        # hits[i, r] is 1.0 when the item at rank r matches query i.
        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # Number of queries that own at least one match.
        hits_per_query = paddle.sum(hits, axis=1)
        has_hit = paddle.greater_than(hits_per_query, paddle.to_tensor(0.))
        valid_queries = paddle.sum(has_hit.astype("float32"))

        # found[i, r] is 1.0 once query i has seen a match at rank <= r.
        running_hits = paddle.cumsum(hits, axis=1)
        found = paddle.greater_than(running_hits,
                                    paddle.to_tensor(0.)).astype("float32")
        cmc = (paddle.sum(found, axis=0) / valid_queries).numpy()

        return {"recall{}".format(k): cmc[k - 1] for k in self.topk}

class Precisionk(nn.Layer):
    """
    Precision@k for retrieval: the number of matches within the first k
    ranked gallery items, averaged over all queries and divided by k.
    Args:
        topk: an int, or a list/tuple of ints. A single int is wrapped into
            a one-element list.
        descending: forwarded to paddle.argsort when ranking the gallery.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """
        Args:
            similarities_matrix: 2-D tensor ranked row by row (axis=1).
            query_img_id: ids compared against the ranked gallery ids.
            gallery_img_id: 2-D tensor of gallery ids (transposed with
                perm [1, 0] before use).
            keep_mask: optional tensor laid out like similarities_matrix;
                entries that cast to False never count as a match.
        Returns:
            dict mapping "precision@<k>" to the precision value for every k.
        """
        # Rank the gallery per query and fetch the ids in ranked order.
        order = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_rows = paddle.broadcast_to(
            gallery_row, shape=[order.shape[0], gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_rows, order)

        # hits[i, r] is 1.0 when the item at rank r matches query i.
        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # ranks = 1, 2, ..., number of gallery items
        ranks = paddle.arange(gallery_img_id.shape[0]) + 1
        running_hits = paddle.cumsum(hits, axis=1)
        precision_by_rank = (paddle.mean(running_hits, axis=0) /
                             ranks).numpy()

        return {
            "precision@{}".format(k): precision_by_rank[k - 1]
            for k in self.topk
        }


class DistillationTopkAcc(TopkAcc):
    """
    Top-k accuracy for one model's output picked out of a dict of outputs.
    Args:
        model_key: key used to select the output when forward() receives a
            dict.
        feature_key: optional second key applied to the selected output;
            skipped when None.
        topk: forwarded to TopkAcc.
    """

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        super().__init__(topk=topk)
        self.feature_key = feature_key
        self.model_key = model_key

    def forward(self, x, label):
        """Select the configured output and score it with TopkAcc.forward."""
        selected = x[self.model_key] if isinstance(x, dict) else x
        if self.feature_key is not None:
            selected = selected[self.feature_key]
        return super().forward(selected, label)


class GoogLeNetTopkAcc(TopkAcc):
    """
    Top-k accuracy computed on the first element of the model output.
    Args:
        topk: an int, or a list/tuple of ints; validated and normalised by
            TopkAcc.
    """

    def __init__(self, topk=(1, 5)):
        # TopkAcc.__init__ performs the same type check and int -> list
        # wrapping, so the value is simply handed over.
        super().__init__(topk=topk)

    def forward(self, x, label):
        """Score x[0] against label with TopkAcc.forward."""
        # NOTE(review): x[0] is presumably the main classifier head, with the
        # auxiliary heads after it -- confirm against the model.
        first_output = x[0]
        return super().forward(first_output, label)


class MutiLabelMetric(object):
    """
    Base class for multilabel metrics: turns raw model output into binary
    per-label predictions.
    """

    def __init__(self):
        pass

    def _multi_hot_encode(self, logits, threshold=0.5):
        """Binarize `logits` with sklearn's binarize at `threshold`."""
        encoded = binarize(logits, threshold=threshold)
        return encoded

    def __call__(self, output):
        """
        Apply a sigmoid to `output`, convert to numpy and binarize at 0.5.
        Returns:
            the binarized predictions.
        """
        probs = F.sigmoid(output).numpy()
        return self._multi_hot_encode(logits=probs, threshold=0.5)


class HammingDistance(MutiLabelMetric):
    """
    Soft, label-based metric for multilabel classification.
    Returns:
        The smaller the return value is, the better the model is.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, output, target):
        """
        Args:
            output: raw model output, binarized by MutiLabelMetric.
            target: ground-truth labels passed to sklearn's hamming_loss.
        Returns:
            dict: {"HammingDistance": paddle tensor holding the loss}
        """
        preds = super().__call__(output)
        distance = hamming_loss(target, preds)
        return {"HammingDistance": paddle.to_tensor(distance)}


class AccuracyScore(MutiLabelMetric):
    """
    Hard metric for multilabel classification
    Args:
        base: ["sample", "label"], default="label"
            if "sample", score with sklearn's accuracy_score(target, preds);
            if "label", score as (TP + TN) / (TP + TN + FN + FP), with the
            counts summed over all labels of the multilabel confusion
            matrix.
    Returns:
        dict: {"AccuracyScore": paddle tensor holding the score}
    """

    def __init__(self, base="label"):
        super().__init__()
        assert base in ["sample", "label"
                        ], 'must be one of ["sample", "label"]'
        self.base = base

    def __call__(self, output, target):
        preds = super().__call__(output)
        if self.base == "sample":
            accuracy = accuracy_metric(target, preds)
        elif self.base == "label":
            confusion = multilabel_confusion_matrix(target, preds)
            # Each per-label 2x2 matrix is indexed as [[TN, FP], [FN, TP]].
            true_pos = sum(confusion[:, 1, 1])
            true_neg = sum(confusion[:, 0, 0])
            false_neg = sum(confusion[:, 1, 0])
            false_pos = sum(confusion[:, 0, 1])
            correct = true_pos + true_neg
            accuracy = correct / (correct + false_neg + false_pos)
        return {"AccuracyScore": paddle.to_tensor(accuracy)}


def get_attr_metrics(gt_label, preds_probs, threshold):
    """
    Count confusion statistics for attribute recognition, per label (summed
    over axis 0) and per instance (summed over axis 1).
    Args:
        gt_label: numpy array of ground-truth labels: 1 = positive,
            0 = negative, -1 = excluded from every count. Assumed
            [num_samples, num_labels] -- TODO confirm.
        preds_probs: numpy array of scores with the same layout.
        threshold: scores strictly greater than this are predicted positive.
    Returns:
        EasyDict of float arrays:
            gt_pos, gt_neg, true_pos, true_neg, false_pos, false_neg
                (per label);
            gt_pos_ins (ground-truth positives per instance),
            true_pos_ins (predicted positives per instance),
            intersect_pos, union_pos (per instance).
    """
    pred_label = (preds_probs > threshold).astype(int)
    # Where the ground truth is -1 the prediction is forced to -1 as well,
    # so the entry matches none of the 0 / 1 tests below.
    pred_label[gt_label == -1] = -1

    is_gt_pos = gt_label == 1
    is_gt_neg = gt_label == 0
    is_pred_pos = pred_label == 1
    is_pred_neg = pred_label == 0

    def _count(mask, axis):
        # Number of True entries along `axis`, as floats.
        return np.sum(mask, axis=axis).astype(float)

    result = EasyDict()

    # ---- label metrics (reduce over axis 0) ----
    result.gt_pos = _count(is_gt_pos, 0)  # TP + FN
    result.gt_neg = _count(is_gt_neg, 0)  # TN + FP
    result.true_pos = _count(is_gt_pos & is_pred_pos, 0)  # TP
    result.true_neg = _count(is_gt_neg & is_pred_neg, 0)  # TN
    result.false_pos = _count(is_gt_neg & is_pred_pos, 0)  # FP
    result.false_neg = _count(is_gt_pos & is_pred_neg, 0)  # FN

    # ---- instance metrics (reduce over axis 1) ----
    result.gt_pos_ins = _count(is_gt_pos, 1)
    result.true_pos_ins = _count(is_pred_pos, 1)
    result.intersect_pos = _count(is_gt_pos & is_pred_pos, 1)  # for IOU
    result.union_pos = _count(is_gt_pos | is_pred_pos, 1)  # for IOU

    return result


class ATTRMetric(nn.Layer):
    """
    Attribute-recognition statistics computed with get_attr_metrics.
    Args:
        threshold: score above which an attribute is predicted positive.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, output, target):
        """
        Args:
            output: tensor of predicted scores.
            target: 3-D tensor; only the slice target[:, 0, :] is used as
                the ground-truth labels.
        Returns:
            the EasyDict produced by get_attr_metrics.
        """
        gt_label = target[:, 0, :].numpy()
        preds_probs = output.numpy()
        return get_attr_metrics(gt_label, preds_probs, self.threshold)