未验证 提交 96c5113c 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #802 from FredHuang16/patch-8

fix retrieval metrics
......@@ -16,8 +16,7 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):
......@@ -25,13 +24,21 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = []
assert isinstance(config_list, list), (
'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs):
metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list):
......
......@@ -84,6 +84,29 @@ class Recallk(nn.Layer):
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
return metric_dict
# retrieval metrics
class RetriMetric(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk']
for k in topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
if "mAP" in self.config.keys():
mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP
if "mINP" in self.config.keys():
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
return metric_dict
@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册