未验证 提交 4f61f908 编写于 作者: B Bin Lu 提交者: GitHub

Update metrics.py

上级 d87f4f15
......@@ -38,14 +38,14 @@ class mAP(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, similarities_matrix, query_labels, gallery_labels):
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0])
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
equal_flag = paddle.equal(choosen_label, query_labels)
equal_flag = paddle.equal(choosen_label, query_img_id)
equal_flag = paddle.cast(equal_flag, 'float32')
acc_sum = paddle.cumsum(equal_flag, axis=1)
......@@ -62,14 +62,14 @@ class mINP(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, similarities_matrix, query_labels, gallery_labels):
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0])
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
tmp = paddle.equal(choosen_label, query_labels)
tmp = paddle.equal(choosen_label, query_img_id)
tmp = paddle.cast(tmp, 'float64')
#do accumulative sum
......@@ -90,15 +90,15 @@ class Recallk(nn.Layer):
topk = [topk]
self.topk = topk
def forward(self, similarities_matrix, query_labels, gallery_labels):
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
#get cmc
choosen_indices = paddle.argsort(similarities_matrix, axis=1, descending=True)
gallery_labels_transpose = paddle.transpose(gallery_labels, [1,0])
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1,0])
gallery_labels_transpose = paddle.broadcast_to(gallery_labels_transpose, shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
choosen_label = paddle.index_sample(gallery_labels_transpose, choosen_indices)
equal_flag = paddle.equal(choosen_label, query_labels)
equal_flag = paddle.equal(choosen_label, query_img_id)
equal_flag = paddle.cast(equal_flag, 'float32')
acc_sum = paddle.cumsum(equal_flag, axis=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册