提交 fa2b698a 编写于 作者: D dongshuilong

fix metric bugs

上级 2da31a87
# global configs # global configs
Trainer:
name: TrainerReID
Global: Global:
checkpoints: null checkpoints: null
pretrained_model: null pretrained_model: null
......
...@@ -39,8 +39,9 @@ class TopkAcc(nn.Layer): ...@@ -39,8 +39,9 @@ class TopkAcc(nn.Layer):
class mAP(nn.Layer): class mAP(nn.Layer):
def __init__(self): def __init__(self, name="mAP"):
super().__init__() super().__init__()
self.name = name
def forward(self, similarities_matrix, query_img_id, gallery_img_id): def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict() metric_dict = dict()
...@@ -48,13 +49,14 @@ class mAP(nn.Layer): ...@@ -48,13 +49,14 @@ class mAP(nn.Layer):
gallery_img_id) gallery_img_id)
mAP = np.mean(all_AP) mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP metric_dict[self.name] = mAP
return metric_dict return metric_dict
class mINP(nn.Layer): class mINP(nn.Layer):
def __init__(self): def __init__(self, name="mINP"):
super().__init__() super().__init__()
self.name = name
def forward(self, similarities_matrix, query_img_id, gallery_img_id): def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict() metric_dict = dict()
...@@ -62,7 +64,7 @@ class mINP(nn.Layer): ...@@ -62,7 +64,7 @@ class mINP(nn.Layer):
gallery_img_id) gallery_img_id)
mINP = np.mean(all_INP) mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP metric_dict[self.name] = mINP
return metric_dict return metric_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册