From 3e1e25afca60e12319121f57f2585b137bc1aa09 Mon Sep 17 00:00:00 2001 From: Felix Date: Tue, 8 Jun 2021 11:48:07 +0800 Subject: [PATCH] Update metrics.py --- ppcls/metric/metrics.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 05723c12..d2e66bc5 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -84,6 +84,29 @@ class Recallk(nn.Layer): metric_dict["recall{}".format(k)] = all_cmc[k - 1] return metric_dict +# retrieval metrics +class RetriMetric(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50 + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id, + gallery_img_id, self.max_rank) + if "Recallk" in self.config.keys(): + topk = self.config['Recallk']['topk'] + for k in topk: + metric_dict["recall{}".format(k)] = all_cmc[k - 1] + if "mAP" in self.config.keys(): + mAP = np.mean(all_AP) + metric_dict["mAP"] = mAP + if "mINP" in self.config.keys(): + mINP = np.mean(all_INP) + metric_dict["mINP"] = mINP + return metric_dict + @lru_cache() def get_metrics(similarities_matrix, query_img_id, gallery_img_id, -- GitLab