metrics.py 8.9 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn

D
dongshuilong 已提交
19

W
weishengyu 已提交
20
class TopkAcc(nn.Layer):
    """Top-k classification accuracy for one or several values of k."""

    def __init__(self, topk=(1, 5)):
        """Store the k values to evaluate.

        Args:
            topk: a single int, or a list/tuple of ints. A single int is
                wrapped into a one-element list.
        """
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, x, label):
        """Compute the accuracy for every configured k.

        Args:
            x: model output, or a dict whose "logits" entry holds it.
            label: ground-truth labels, passed to `paddle.metric.accuracy`
                unchanged.

        Returns:
            dict mapping "top<k>" to the result of
            `paddle.metric.accuracy(x, label, k=k)`.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        return {
            "top{}".format(k): paddle.metric.accuracy(
                logits, label, k=k)
            for k in self.topk
        }

D
dongshuilong 已提交
38

W
weishengyu 已提交
39
class mAP(nn.Layer):
    """Mean average precision (mAP) of a ranked retrieval result."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mAP over the queries that have at least one match.

        Args:
            similarities_matrix: 2-D tensor of similarity scores; each row
                is sorted in descending order along axis 1.
                Presumably [num_query, num_gallery] -- TODO confirm.
            query_img_id: ids compared (with broadcasting) against the
                ranked gallery ids. Presumably [num_query, 1] -- TODO
                confirm against the caller.
            gallery_img_id: 2-D tensor of gallery ids; it is transposed and
                broadcast along the query axis.
                Presumably [num_gallery, 1] -- TODO confirm.
            keep_mask: optional tensor laid out like `similarities_matrix`;
                entries that cast to False are never counted as matches.
                May be None.

        Returns:
            dict with a single key "mAP" holding a numpy scalar.
        """
        metric_dict = dict()

        # Rank the gallery for every query, best match first.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        # Gallery ids re-ordered per query according to the ranking.
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # The mask must follow the same per-query ordering as the ids.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Drop queries without any match: their AP would be 0 / 0.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # precision[q, j] = matches within the first j + 1 results / (j + 1)
        acc_sum = paddle.cumsum(equal_flag, axis=1)
        div = paddle.arange(acc_sum.shape[1]).astype("float32") + 1
        precision = paddle.divide(acc_sum, div)

        # AP = mean of the precision values taken at the matching positions.
        precision_mask = paddle.multiply(equal_flag, precision)
        ap = paddle.sum(precision_mask, axis=1) / paddle.sum(equal_flag,
                                                             axis=1)
        # paddle.mean may return a 0-D or a 1-element tensor depending on the
        # Paddle version; flattening first makes the scalar extraction work
        # for both while still yielding a numpy scalar.
        metric_dict["mAP"] = paddle.mean(ap).numpy().reshape(-1)[0]
        return metric_dict

D
dongshuilong 已提交
82

W
weishengyu 已提交
83
class mINP(nn.Layer):
    """Mean inverse negative penalty (mINP) of a ranked retrieval result."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mINP over the queries that have at least one match.

        For every such query, INP = (number of matches) / (1-based rank of
        the last, i.e. hardest, match).

        Args:
            similarities_matrix: 2-D tensor of similarity scores; each row
                is sorted in descending order along axis 1.
                Presumably [num_query, num_gallery] -- TODO confirm.
            query_img_id: ids compared (with broadcasting) against the
                ranked gallery ids. Presumably [num_query, 1] -- TODO
                confirm against the caller.
            gallery_img_id: 2-D tensor of gallery ids; it is transposed and
                broadcast along the query axis.
                Presumably [num_gallery, 1] -- TODO confirm.
            keep_mask: optional tensor laid out like `similarities_matrix`;
                entries that cast to False are never counted as matches.
                May be None.

        Returns:
            dict with a single key "mINP" holding a numpy scalar.
        """
        metric_dict = dict()

        # Rank the gallery for every query, best match first.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        # Gallery ids re-ordered per query according to the ranking.
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # The mask must follow the same per-query ordering as the ids.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Drop queries without any match: INP is undefined for them.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # Weight every match at 0-based position j by (j + 1) / (j + 2).
        # The weight grows strictly with j, so argmax lands on the LAST match.
        div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
        minus = paddle.divide(equal_flag, div)
        auxilary = paddle.subtract(equal_flag, minus)
        hard_index = paddle.argmax(auxilary, axis=1).astype("float32")
        # argmax is a 0-based index; the INP denominator is the 1-based rank.
        # Without the + 1 a query whose only match is ranked first divides by
        # zero, and every other INP value is overestimated.
        hard_rank = hard_index + 1
        all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_rank)
        mINP = paddle.mean(all_INP)
        # paddle.mean may return a 0-D or a 1-element tensor depending on the
        # Paddle version; flattening first handles both.
        metric_dict["mINP"] = mINP.numpy().reshape(-1)[0]
        return metric_dict

D
dongshuilong 已提交
125

W
weishengyu 已提交
126
class Recallk(nn.Layer):
    """Recall@k (CMC-style) of a ranked retrieval result."""

    def __init__(self, topk=(1, 5)):
        """Store the k values to evaluate.

        Args:
            topk: a single int, or a list/tuple of ints. A single int is
                wrapped into a one-element list.
        """
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute recall@k for every configured k.

        recall@k is the number of queries with a match among their first k
        ranked gallery items, divided by the number of queries that have at
        least one match anywhere in the ranking.

        Args:
            similarities_matrix: 2-D tensor of similarity scores, ranked in
                descending order along axis 1.
            query_img_id: ids compared (with broadcasting) against the
                ranked gallery ids -- shape not visible here, TODO confirm.
            gallery_img_id: 2-D tensor of gallery ids; transposed and
                broadcast along the query axis.
            keep_mask: optional tensor laid out like `similarities_matrix`;
                entries that cast to False never count as matches.

        Returns:
            dict mapping "recall<k>" to a numpy scalar.
        """
        zero = paddle.to_tensor(0.)

        # Per-query ranking of the gallery, best match first.
        ranking = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_grid = paddle.broadcast_to(
            gallery_row, shape=[ranking.shape[0], gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_grid, ranking)

        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            # Re-order the mask so that it lines up with the ranked ids.
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), ranking)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # Number of queries that can be answered at all.
        hits_per_query = paddle.sum(hits, axis=1)
        answerable = paddle.greater_than(hits_per_query, zero)
        real_query_num = paddle.sum(answerable.astype("float32"))

        # found[q, j] is 1 once query q has a match within its first j + 1
        # ranked results.
        running_hits = paddle.cumsum(hits, axis=1)
        found = paddle.greater_than(running_hits, zero).astype("float32")
        all_cmc = (paddle.sum(found, axis=0) / real_query_num).numpy()

        return {"recall{}".format(k): all_cmc[k - 1] for k in self.topk}

D
dongshuilong 已提交
170

B
Bin Lu 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
class Precisionk(nn.Layer):
    """Precision@k of a ranked retrieval result, averaged over all queries."""

    def __init__(self, topk=(1, 5)):
        """Store the k values to evaluate.

        Args:
            topk: a single int, or a list/tuple of ints. A single int is
                wrapped into a one-element list.
        """
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute precision@k for every configured k.

        precision@k is the mean, over every query, of the number of matches
        among the first k ranked gallery items divided by k.

        Args:
            similarities_matrix: 2-D tensor of similarity scores, ranked in
                descending order along axis 1.
            query_img_id: ids compared (with broadcasting) against the
                ranked gallery ids -- shape not visible here, TODO confirm.
            gallery_img_id: 2-D tensor of gallery ids; its first dimension
                is used as the number of ranks.
            keep_mask: optional tensor laid out like `similarities_matrix`;
                entries that cast to False never count as matches.

        Returns:
            dict mapping "precision@<k>" to a numpy scalar.
        """
        metric_dict = dict()

        # Rank the gallery for every query, best match first.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # The mask must follow the same per-query ordering as the ids.
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # 1-based ranks as float32 so the division below is float / float
        # instead of relying on implicit handling of a float32 / int64 mix.
        Ns = paddle.arange(gallery_img_id.shape[0]).astype("float32") + 1
        equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1)
        Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy()

        for k in self.topk:
            metric_dict["precision@{}".format(k)] = Precision_at_k[k - 1]

        return metric_dict


212 213 214 215 216 217 218 219 220 221 222
class DistillationTopkAcc(TopkAcc):
    """Top-k accuracy computed on one entry of a keyed model output."""

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        """Remember which entry of the output to evaluate.

        Args:
            model_key: key used to index the output passed to `forward`.
            feature_key: optional second-level key applied after
                `model_key`; skipped when None.
            topk: forwarded to `TopkAcc`.
        """
        super().__init__(topk=topk)
        self.feature_key = feature_key
        self.model_key = model_key

    def forward(self, x, label):
        """Select `x[model_key]` (and `[feature_key]` if set), then score it."""
        selected = x[self.model_key]
        if self.feature_key is None:
            return super().forward(selected, label)
        return super().forward(selected[self.feature_key], label)
C
cuicheng01 已提交
223 224 225 226 227 228 229 230 231 232 233 234


class GoogLeNetTopkAcc(TopkAcc):
    """Top-k accuracy computed on the first element of the model output.

    NOTE(review): presumably the remaining elements are GoogLeNet's
    auxiliary outputs -- confirm against the model definition.
    """

    def __init__(self, topk=(1, 5)):
        """Store the k values to evaluate.

        Args:
            topk: a single int, or a list/tuple of ints; validated and
                normalised by `TopkAcc.__init__`.
        """
        # Let the parent validate and normalise topk instead of repeating
        # that logic here.
        super().__init__(topk=topk)

    def forward(self, x, label):
        """Score `x[0]` against `label` with the parent implementation."""
        return super().forward(x[0], label)