未验证 提交 9e83fb4c 编写于 作者: B Bin Lu 提交者: GitHub

Update metrics.py

上级 c24dc93c
...@@ -69,7 +69,7 @@ class mINP(nn.Layer): ...@@ -69,7 +69,7 @@ class mINP(nn.Layer):
class Recallk(nn.Layer): class Recallk(nn.Layer):
def __init__(self, topk=(1, 5)): def __init__(self, topk=(1, 5)):
super().__init__() super().__init__()
assert isinstance(topk, (int, list)) assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int): if isinstance(topk, int):
topk = [topk] topk = [topk]
self.topk = topk self.topk = topk
...@@ -97,6 +97,10 @@ class RetriMetric(nn.Layer): ...@@ -97,6 +97,10 @@ class RetriMetric(nn.Layer):
gallery_img_id, self.max_rank) gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys(): if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk'] topk = self.config['Recallk']['topk']
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
for k in topk: for k in topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1] metric_dict["recall{}".format(k)] = all_cmc[k - 1]
if "mAP" in self.config.keys(): if "mAP" in self.config.keys():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册