From c9f2c4742f3d39f534fdcc8f0ce6e36352fd6866 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Fri, 4 Jun 2021 22:19:04 +0800 Subject: [PATCH] update metric --- ppcls/metric/__init__.py | 0 ppcls/metric/metrics.py | 151 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 ppcls/metric/__init__.py create mode 100644 ppcls/metric/metrics.py diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py new file mode 100644 index 00000000..fca2f9fc --- /dev/null +++ b/ppcls/metric/metrics.py @@ -0,0 +1,151 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn + + +# TODO: fix the format +class Topk(nn.Layer): + def __init__(self, topk=(1, 5)): + super().__init__() + assert isinstance(topk, (int, list, tuple)) + if isinstance(topk, int): + topk = [topk] + self.topk = topk + + def forward(self, x, label): + if isinstance(x, dict): + x = x["logits"] + + metric_dict = dict() + for k in self.topk: + metric_dict["top{}".format(k)] = paddle.metric.accuracy( + x, label, k=k) + return metric_dict + + +class mAP(nn.Layer): + def __init__(self, max_rank=50): + super().__init__() + self.max_rank = max_rank + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + num_q, num_g = similarities_matrix.shape + q_pids = query_img_id.numpy().reshape((query_img_id.shape[0])) + g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0])) + if num_g < self.max_rank: + self.max_rank = num_g + print('Note: number of gallery samples is quite small, got {}'. + format(num_g)) + indices = paddle.argsort( + similarities_matrix, axis=1, descending=True).numpy() + _, all_AP, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids, + self.max_rank) + + mAP = np.mean(all_AP) + metric_dict["mAP"] = mAP + return metric_dict + + +class mINP(nn.Layer): + def __init__(self, max_rank=50): + super().__init__() + self.max_rank = max_rank + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + num_q, num_g = similarities_matrix.shape + q_pids = query_img_id.numpy().reshape((query_img_id.shape[0])) + g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0])) + if num_g < self.max_rank: + max_rank = num_g + print('Note: number of gallery samples is quite small, got {}'. + format(num_g)) + indices = paddle.argsort( + similarities_matrix, axis=1, descending=True).numpy() + _, _, all_INP = get_metrics(indices, num_q, num_g, q_pids, g_pids, + self.max_rank) + + mINP = np.mean(all_INP) + metric_dict["mINP"] = mINP + return metric_dict + + +class Recallk(nn.Layer): + def __init__(self, max_rank=50, topk=(1, 5)): + super().__init__() + self.max_rank = max_rank + assert isinstance(topk, (int, list)) + if isinstance(topk, int): + topk = [topk] + self.topk = topk + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + num_q, num_g = similarities_matrix.shape + q_pids = query_img_id.numpy().reshape((query_img_id.shape[0])) + g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0])) + if num_g < self.max_rank: + max_rank = num_g + print('Note: number of gallery samples is quite small, got {}'. + format(num_g)) + indices = paddle.argsort( + similarities_matrix, axis=1, descending=True).numpy() + all_cmc, _, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids, + self.max_rank) + + for k in self.topk: + metric_dict["recall{}".format(k)] = all_cmc[k - 1] + return metric_dict + + +def get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank=50): + all_cmc = [] + all_AP = [] + all_INP = [] + num_valid_q = 0 + matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) + for q_idx in range(num_q): + q_pid = q_pids[q_idx] + order = indices[q_idx] + remove = g_pids[order] == q_pid + keep = np.invert(remove) + raw_cmc = matches[q_idx][keep] + if not np.any(raw_cmc): + continue + cmc = raw_cmc.cumsum() + pos_idx = np.where(raw_cmc == 1) + max_pos_idx = np.max(pos_idx) + inp = cmc[max_pos_idx] / (max_pos_idx + 1.0) + all_INP.append(inp) + cmc[cmc > 1] = 1 + + all_cmc.append(cmc[:max_rank]) + num_valid_q += 1. + + num_rel = raw_cmc.sum() + tmp_cmc = raw_cmc.cumsum() + tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] + tmp_cmc = np.asarray(tmp_cmc) * raw_cmc + AP = tmp_cmc.sum() / num_rel + all_AP.append(AP) + assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' + + all_cmc = np.asarray(all_cmc).astype(np.float32) + all_cmc = all_cmc.sum(0) / num_valid_q + + return all_cmc, all_AP, all_INP -- GitLab