未验证 提交 96c5113c 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #802 from FredHuang16/patch-8

fix retrieval metrics
...@@ -16,8 +16,7 @@ from paddle import nn ...@@ -16,8 +16,7 @@ from paddle import nn
import copy import copy
from collections import OrderedDict from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
...@@ -25,13 +24,21 @@ class CombinedMetrics(nn.Layer): ...@@ -25,13 +24,21 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = [] self.metric_func_list = []
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list: for config in config_list:
assert isinstance(config, assert isinstance(config,
dict) and len(config) == 1, "yaml format error" dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0] metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name] metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params)) self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
metric_dict = OrderedDict() metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list): for idx, metric_func in enumerate(self.metric_func_list):
......
...@@ -84,6 +84,29 @@ class Recallk(nn.Layer): ...@@ -84,6 +84,29 @@ class Recallk(nn.Layer):
metric_dict["recall{}".format(k)] = all_cmc[k - 1] metric_dict["recall{}".format(k)] = all_cmc[k - 1]
return metric_dict return metric_dict
# retrieval metrics
class RetriMetric(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk']
for k in topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
if "mAP" in self.config.keys():
mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP
if "mINP" in self.config.keys():
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
return metric_dict
@lru_cache() @lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id, def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册