diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 05723c12303fdf616110a0cccffb883058f62674..d2e66bc54dc298ec329b45305684eb33f39da11b 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -84,6 +84,29 @@ class Recallk(nn.Layer): metric_dict["recall{}".format(k)] = all_cmc[k - 1] return metric_dict +# retrieval metrics +class RetriMetric(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50 + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id, + gallery_img_id, self.max_rank) + if "Recallk" in self.config.keys(): + topk = self.config['Recallk']['topk'] + for k in topk: + metric_dict["recall{}".format(k)] = all_cmc[k - 1] + if "mAP" in self.config.keys(): + mAP = np.mean(all_AP) + metric_dict["mAP"] = mAP + if "mINP" in self.config.keys(): + mINP = np.mean(all_INP) + metric_dict["mINP"] = mINP + return metric_dict + @lru_cache() def get_metrics(similarities_matrix, query_img_id, gallery_img_id,