未验证 提交 4f61f908 编写于 作者: B Bin Lu 提交者: GitHub

Update metrics.py

上级 d87f4f15
...@@ -38,14 +38,14 @@ class mAP(nn.Layer): ...@@ -38,14 +38,14 @@ class mAP(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, similarities_matrix, query_labels, gallery_labels): def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict() metric_dict = dict()
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True) choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0]) gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]]) gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices) choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
equal_flag = paddle.equal(choosen_label, query_labels) equal_flag = paddle.equal(choosen_label, query_img_id)
equal_flag = paddle.cast(equal_flag, 'float32') equal_flag = paddle.cast(equal_flag, 'float32')
acc_sum = paddle.cumsum(equal_flag, axis=1) acc_sum = paddle.cumsum(equal_flag, axis=1)
...@@ -62,14 +62,14 @@ class mINP(nn.Layer): ...@@ -62,14 +62,14 @@ class mINP(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, similarities_matrix, query_labels, gallery_labels): def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict() metric_dict = dict()
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True) choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0]) gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]]) gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices) choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
tmp = paddle.equal(choosen_label, query_labels) tmp = paddle.equal(choosen_label, query_img_id)
tmp = paddle.cast(tmp, 'float64') tmp = paddle.cast(tmp, 'float64')
#do accumulative sum #do accumulative sum
...@@ -90,15 +90,15 @@ class Recallk(nn.Layer): ...@@ -90,15 +90,15 @@ class Recallk(nn.Layer):
topk = [topk] topk = [topk]
self.topk = topk self.topk = topk
def forward(self, similarities_matrix, query_labels, gallery_labels): def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict() metric_dict = dict()
#get cmc #get cmc
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True) choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0]) gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]]) gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices) choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
equal_flag = paddle.equal(choosen_label, query_labels) equal_flag = paddle.equal(choosen_label, query_img_id)
equal_flag = paddle.cast(equal_flag, 'float32') equal_flag = paddle.cast(equal_flag, 'float32')
acc_sum = paddle.cumsum(equal_flag, axis=1) acc_sum = paddle.cumsum(equal_flag, axis=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册