未验证 提交 9e83fb4c 编写于 作者: B Bin Lu 提交者: GitHub

Update metrics.py

上级 c24dc93c
......@@ -69,7 +69,7 @@ class mINP(nn.Layer):
class Recallk(nn.Layer):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list))
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
......@@ -97,6 +97,10 @@ class RetriMetric(nn.Layer):
gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk']
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
for k in topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
if "mAP" in self.config.keys():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册