# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
from functools import lru_cache


# TODO: fix the format
class TopkAcc(nn.Layer):
    """Top-k classification accuracy for one or more values of k.

    Args:
        topk (int | list | tuple): the k value(s) to report. A single int
            is wrapped into a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, x, label):
        """Return a dict mapping "top<k>" to paddle's top-k accuracy.

        Args:
            x: the logits, or a dict holding them under the "logits" key.
            label: ground-truth labels passed to ``paddle.metric.accuracy``.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        return {
            "top{}".format(k): paddle.metric.accuracy(logits, label, k=k)
            for k in self.topk
        }


class mAP(nn.Layer):
    """Mean average precision over the queries that have a gallery match."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return {"mAP": mean of the per-query AP values from get_metrics}."""
        # get_metrics returns (cmc, per-query AP list, per-query INP list).
        per_query_ap = get_metrics(similarities_matrix, query_img_id,
                                   gallery_img_id)[1]
        return {"mAP": np.mean(per_query_ap)}


class mINP(nn.Layer):
    """Mean of the per-query INP values over queries with a gallery match."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return {"mINP": mean of the per-query INP values from get_metrics}."""
        # get_metrics returns (cmc, per-query AP list, per-query INP list).
        per_query_inp = get_metrics(similarities_matrix, query_img_id,
                                    gallery_img_id)[2]
        return {"mINP": np.mean(per_query_inp)}


class Recallk(nn.Layer):
    """Recall@k (the CMC curve read at rank k) for retrieval evaluation.

    Args:
        topk (int | list | tuple): rank(s) at which recall is reported. A
            single int is wrapped into a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        # tuple must be accepted: the default value itself is a tuple.
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk
        # Ask get_metrics for at least 50 ranks, more if a larger k is wanted.
        self.max_rank = max(50, max(self.topk))

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return a dict mapping "recall<k>" to the CMC value at rank k."""
        metric_dict = dict()
        all_cmc, _, _ = get_metrics(similarities_matrix, query_img_id,
                                    gallery_img_id, self.max_rank)

        # get_metrics truncates the CMC curve to the gallery size when the
        # gallery holds fewer than max_rank samples. The curve cannot change
        # past the last gallery item, so clamp k to the last available rank
        # instead of raising IndexError.
        last_rank = len(all_cmc)
        for k in self.topk:
            metric_dict["recall{}".format(k)] = all_cmc[min(k, last_rank) - 1]
        return metric_dict


# Results are memoized on the argument objects themselves, so mAP / mINP /
# Recallk calls on the same tensors can share one computation. The cache is
# kept small on purpose: every entry pins a full similarity matrix (plus the
# id tensors) in memory, and only consecutive metric calls on the same
# tensors benefit. NOTE(review): hits presumably rely on paddle Tensors
# hashing by identity -- confirm for the Paddle version in use.
@lru_cache(maxsize=4)
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
                max_rank=50):
    """Compute the CMC curve, per-query AP and per-query INP for retrieval.

    Args:
        similarities_matrix: 2-D tensor of shape (num_q, num_g). Higher values
            rank a gallery item earlier (sorted in descending order).
        query_img_id: tensor whose first dimension is num_q; it is flattened
            to one id per query.
        gallery_img_id: tensor whose first dimension is num_g; it is flattened
            to one id per gallery item.
        max_rank: length at which the CMC curve is truncated. It is lowered to
            num_g when the gallery is smaller.

    Returns:
        (all_cmc, all_AP, all_INP):
            all_cmc: float32 array of length min(max_rank, num_g), the CMC
                curve averaged over the valid queries.
            all_AP: list with one average-precision value per valid query.
            all_INP: list with one INP value per valid query (number of
                matches divided by the 1-based rank of the last match).
        Queries whose id never appears in the gallery are skipped.

    Raises:
        AssertionError: if no query has a match in the gallery.
    """
    num_q, num_g = similarities_matrix.shape
    q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
    g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.format(
            num_g))
    # Gallery positions for each query, most similar first.
    indices = paddle.argsort(
        similarities_matrix, axis=1, descending=True).numpy()

    all_cmc = []
    all_AP = []
    all_INP = []
    num_valid_q = 0
    # matches[q, r] == 1 iff the gallery item ranked r for query q has q's id.
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    # 1-based ranks as floats, shared by every query's precision@r.
    ranks = np.arange(1, num_g + 1, dtype=np.float64)
    for q_idx in range(num_q):
        raw_cmc = matches[q_idx]
        if not np.any(raw_cmc):
            # No gallery item shares this query's id: excluded from metrics.
            continue
        # hits[r] = number of matches within the top r + 1 results (unclipped).
        hits = raw_cmc.cumsum()

        # INP: total matches divided by the 1-based rank of the last match.
        last_match = np.flatnonzero(raw_cmc)[-1]
        all_INP.append(hits[last_match] / (last_match + 1.0))

        # CMC: 1 from the first match onward, truncated to max_rank.
        all_cmc.append(np.minimum(hits[:max_rank], 1))
        num_valid_q += 1.

        # AP: mean of precision@r over the ranks r that hold a match
        # (vectorized; same float64 values as an element-by-element loop).
        precision = hits / ranks
        all_AP.append((precision * raw_cmc).sum() / raw_cmc.sum())
    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP