metrics.py 5.5 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
D
dongshuilong 已提交
18
from functools import lru_cache
W
weishengyu 已提交
19 20


W
weishengyu 已提交
21
class TopkAcc(nn.Layer):
    """Top-k classification accuracy, reported for one or more values of k.

    Args:
        topk: an int or a list/tuple of ints; every value k produces a
            ``top{k}`` entry in the dict returned by ``forward``.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        # A bare int is promoted to a one-element list so forward can iterate.
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, x, label):
        """Return ``{"top{k}": paddle.metric.accuracy(logits, label, k=k)}``.

        ``x`` is either the logits themselves or a dict that holds them under
        the ``"logits"`` key.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        return {
            "top{}".format(k): paddle.metric.accuracy(logits, label, k=k)
            for k in self.topk
        }


class mAP(nn.Layer):
    """Mean average precision over the queries that ``get_metrics`` scores."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"mAP": mean of the per-query AP values}``.

        The per-query AP list is element 1 of ``get_metrics``' result; queries
        whose id never occurs in the gallery are skipped there.
        """
        # Called with exactly three positional arguments so the memoised
        # result can be shared with other metrics called the same way.
        per_query_ap = get_metrics(similarities_matrix, query_img_id,
                                   gallery_img_id)[1]
        return {"mAP": np.mean(per_query_ap)}


class mINP(nn.Layer):
    """Mean inverse negative penalty (INP) over the scored queries."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"mINP": mean of the per-query INP values}``.

        The per-query INP list is element 2 of ``get_metrics``' result;
        queries whose id never occurs in the gallery are skipped there.
        """
        # Same three-positional-argument call as mAP, so both can share the
        # memoised get_metrics result.
        per_query_inp = get_metrics(similarities_matrix, query_img_id,
                                    gallery_img_id)[2]
        return {"mINP": np.mean(per_query_inp)}


class Recallk(nn.Layer):
    """Recall@k (the CMC curve read at rank k) for retrieval evaluation.

    Args:
        topk: an int or a list/tuple of ints; every value k produces a
            ``recall{k}`` entry in the dict returned by ``forward``.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk
        # The CMC curve must reach the largest requested k; never request
        # fewer than 50 ranks (get_metrics' default).
        self.max_rank = max(self.topk) if max(self.topk) > 50 else 50

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"recall{k}": CMC value at rank k}`` for every k in topk."""
        metric_dict = dict()
        all_cmc, _, _ = get_metrics(similarities_matrix, query_img_id,
                                    gallery_img_id, self.max_rank)

        # get_metrics truncates the CMC curve to the gallery size, so a k
        # larger than the gallery used to index past the end (IndexError).
        # Recall at any rank beyond the gallery equals recall at its last
        # rank, so clamp to that position.
        last_rank = len(all_cmc)
        for k in self.topk:
            metric_dict["recall{}".format(k)] = all_cmc[min(k, last_rank) - 1]
        return metric_dict

86

F
Felix 已提交
87 88 89 90 91 92 93 94 95
# retrieval metrics
class RetriMetric(nn.Layer):
    """Report several retrieval metrics from a single ``get_metrics`` call.

    Args:
        config: mapping whose keys select the metrics to report:
            ``"Recallk"`` (reads ``config["Recallk"]["topk"]``, an int or a
            list/tuple of ints), ``"mAP"`` and ``"mINP"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The CMC curve has to reach the largest configured k. This used to be
        # hard-coded to 50, so a Recallk topk above 50 raised IndexError.
        self.max_rank = max([50] + self._recall_topk())

    def _recall_topk(self):
        """Return the configured Recall@k values as a list ([] if unset)."""
        if "Recallk" not in self.config:
            return []
        topk = self.config['Recallk']['topk']
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        return list(topk)

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return a dict with the ``recall{k}``, ``mAP``, ``mINP`` entries
        selected by ``config``."""
        metric_dict = dict()
        all_cmc, all_AP, all_INP = get_metrics(
            similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
        if "Recallk" in self.config:
            # get_metrics truncates the CMC curve to the gallery size; recall
            # at any rank beyond the gallery equals recall at its last rank.
            last_rank = len(all_cmc)
            for k in self._recall_topk():
                metric_dict["recall{}".format(k)] = all_cmc[
                    min(k, last_rank) - 1]
        if "mAP" in self.config:
            metric_dict["mAP"] = np.mean(all_AP)
        if "mINP" in self.config:
            metric_dict["mINP"] = np.mean(all_INP)
        return metric_dict
112

W
weishengyu 已提交
113

D
dongshuilong 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126
@lru_cache(maxsize=4)
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
                max_rank=50):
    """Compute the CMC curve, per-query AP and per-query INP.

    The result is memoised with ``functools.lru_cache``, so the mAP, mINP and
    Recallk layers can share one computation when they are called with the
    same argument objects. The cache keys on the arguments' hashes (for
    paddle tensors presumably object identity -- confirm). Every cache entry
    keeps its arguments, including the full similarity matrix, alive. The
    previous default of 128 entries could therefore pin the matrices of up to
    128 past evaluations, so the cache is now bounded to a handful of entries.

    Args:
        similarities_matrix: 2-D tensor of shape (num_q, num_g); gallery items
            are ranked per query by descending similarity.
        query_img_id: tensor holding num_q ids (reshaped to 1-D).
        gallery_img_id: tensor holding num_g ids (reshaped to 1-D).
        max_rank: number of CMC positions returned; clipped to num_g.

    Returns:
        ``(all_cmc, all_AP, all_INP)``:
        - ``all_cmc`` is a float32 array of length ``min(max_rank, num_g)``
          whose entry i is the fraction of valid queries with a correct match
          in the top i + 1.
        - ``all_AP`` and ``all_INP`` are lists with one value per valid
          query.
        A query is valid when its id occurs in the gallery.

    Raises:
        AssertionError: if no query id occurs in the gallery.
    """
    num_q, num_g = similarities_matrix.shape
    q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
    g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.format(
            num_g))
    indices = paddle.argsort(
        similarities_matrix, axis=1, descending=True).numpy()

    # matches[q, r] == 1 iff the gallery item ranked r for query q has the
    # same id as the query.
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    # 1-based rank positions, shared by every query's precision@r.
    ranks = np.arange(1, num_g + 1)

    all_cmc = []
    all_AP = []
    all_INP = []
    num_valid_q = 0
    for raw_cmc in matches:
        if not np.any(raw_cmc):
            # The query id is absent from the gallery: nothing to score.
            continue
        num_valid_q += 1.
        # hits_at_rank[r] = number of correct matches within the top r + 1.
        hits_at_rank = raw_cmc.cumsum()
        num_rel = hits_at_rank[-1]

        # INP: fraction of the ranked list, up to the last correct match,
        # that is correct.
        last_pos = np.flatnonzero(raw_cmc)[-1]
        all_INP.append(num_rel / (last_pos + 1.0))

        # CMC: 0 before the first correct match, 1 from it onwards.
        all_cmc.append(np.minimum(hits_at_rank, 1)[:max_rank])

        # AP: mean of precision@r over the ranks r that hold a correct match
        # (vectorized; previously a Python loop over the whole gallery).
        precision = hits_at_rank / ranks
        all_AP.append((precision * raw_cmc).sum() / num_rel)
    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP
158 159 160 161 162 163 164 165 166 167 168 169 170


class DistillationTopkAcc(TopkAcc):
    """TopkAcc evaluated on ``x[model_key]`` (or ``x[model_key][feature_key]``).

    Args:
        model_key: key that selects one entry of the input mapping.
        feature_key: optional second-level key inside that entry; ``None``
            means the entry itself is passed on.
        topk: forwarded unchanged to ``TopkAcc``.
    """

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        super().__init__(topk=topk)
        self.model_key = model_key
        self.feature_key = feature_key

    def forward(self, x, label):
        """Select the configured sub-output and delegate to ``TopkAcc``."""
        selected = x[self.model_key]
        target = selected if self.feature_key is None else selected[
            self.feature_key]
        return super().forward(target, label)