提交 fa2b698a 编写于 作者: D dongshuilong

fix metric bugs

上级 2da31a87
# global configs
Trainer:
name: TrainerReID
Global:
checkpoints: null
pretrained_model: null
......
......@@ -39,8 +39,9 @@ class TopkAcc(nn.Layer):
class mAP(nn.Layer):
def __init__(self):
def __init__(self, name="mAP"):
super().__init__()
self.name = name
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
......@@ -48,13 +49,14 @@ class mAP(nn.Layer):
gallery_img_id)
mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP
metric_dict[self.name] = mAP
return metric_dict
class mINP(nn.Layer):
def __init__(self):
def __init__(self, name="mINP"):
super().__init__()
self.name = name
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
......@@ -62,7 +64,7 @@ class mINP(nn.Layer):
gallery_img_id)
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
metric_dict[self.name] = mINP
return metric_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册