未验证 提交 3e1e25af 编写于 作者: F Felix 提交者: GitHub

Update metrics.py

上级 c6be2cd1
......@@ -84,6 +84,29 @@ class Recallk(nn.Layer):
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
return metric_dict
# retrieval metrics
class RetriMetric(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk']
for k in topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
if "mAP" in self.config.keys():
mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP
if "mINP" in self.config.keys():
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
return metric_dict
@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册