未验证 提交 e59c0efe 编写于 作者: B Bin Lu 提交者: GitHub

Merge pull request #830 from Intsigstephon/develop_reg

retrieval metric speed up 
......@@ -29,10 +29,6 @@ class ICartoonDataset(CommonDataset):
with open(self._cls_path) as fd:
lines = fd.readlines()
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
else:
np.random.shuffle(lines)
for l in lines:
l = l.strip().split("\t")
self.images.append(os.path.join(self._img_root, l[0]))
......
......@@ -401,9 +401,7 @@ class Trainer(object):
name='gallery')
query_feas, query_img_id, query_query_id = self._cal_feature(
name='query')
gallery_img_id = gallery_img_id
# if gallery_unique_id is not None:
# gallery_unique_id = gallery_unique_id
# step2. do evaluation
sim_block_size = self.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
......@@ -440,13 +438,9 @@ class Trainer(object):
for key in metric_tmp:
if key not in metric_dict:
metric_dict[key] = metric_tmp[key]
metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len(query_feas)
else:
metric_dict[key] += metric_tmp[key]
num_sections = len(fea_blocks)
for key in metric_dict:
metric_dict[key] = metric_dict[key] / num_sections
metric_dict[key] += metric_tmp[key] * block_fea.shape[0] / len(query_feas)
metric_info_list = []
for key in metric_dict:
......
......@@ -16,39 +16,31 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import DistillationTopkAcc
class CombinedMetrics(nn.Layer):
    """Build several metric layers from config and evaluate them together.

    Args:
        config_list (list): one single-key dict per metric, mapping the
            metric class name to its constructor keyword arguments, or to
            ``None`` when the metric takes no arguments.
    """

    def __init__(self, config_list):
        super().__init__()
        self.metric_func_list = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            metric_name = list(config)[0]
            metric_params = config[metric_name]
            # NOTE(review): eval() resolves the class from a name taken from
            # the config file; this is only safe while configs are trusted.
            metric_cls = eval(metric_name)
            # A metric configured without parameters has None as its value,
            # which cannot be unpacked with ** and must be built bare.
            if metric_params is not None:
                self.metric_func_list.append(metric_cls(**metric_params))
            else:
                self.metric_func_list.append(metric_cls())

    def __call__(self, *args, **kwargs):
        """Call every metric with the same arguments and merge the results.

        Returns:
            OrderedDict: union of the dicts returned by each metric, in
            configuration order; a later metric overwrites an equal key.
        """
        metric_dict = OrderedDict()
        for metric_func in self.metric_func_list:
            metric_dict.update(metric_func(*args, **kwargs))
        return metric_dict
def build_metrics(config):
    """Return a CombinedMetrics built from a deep copy of ``config``.

    The copy keeps the caller's configuration object untouched by whatever
    the metric constructors do with it.
    """
    private_config = copy.deepcopy(config)
    return CombinedMetrics(private_config)
......@@ -15,8 +15,6 @@
import numpy as np
import paddle
import paddle.nn as nn
from functools import lru_cache
class TopkAcc(nn.Layer):
def __init__(self, topk=(1, 5)):
......@@ -36,35 +34,54 @@ class TopkAcc(nn.Layer):
x, label, k=k)
return metric_dict
class mAP(nn.Layer):
    """Mean average precision of a retrieval ranking, computed with Paddle ops."""

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Compute mAP over the queries that have a match in the gallery.

        Args:
            similarities_matrix: 2-D tensor, one row per query and one
                column per gallery item; larger means more similar.
            query_img_id: id of each query, compared against the ranked
                gallery ids (presumably shape [num_query, 1] - confirm).
            gallery_img_id: 2-D tensor of gallery ids, transposed and
                broadcast to one row per query below.

        Returns:
            dict: ``{"mAP": value}``.

        Raises:
            AssertionError: if no query id appears in the gallery.
        """
        metric_dict = dict()

        # Rank the gallery for every query, best match first, and look up
        # the id found at each rank.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        equal_flag = paddle.cast(equal_flag, 'float32')

        # A query without any relevant gallery item has an undefined AP
        # (0 / 0); drop such rows so they cannot turn the mean into NaN.
        num_rel = paddle.sum(equal_flag, axis=1)
        has_rel = paddle.greater_than(num_rel, paddle.zeros_like(num_rel))
        valid_rows = paddle.nonzero(has_rel.astype("int64"))
        assert valid_rows.shape[0] > 0, \
            'Error: all query identities do not appear in gallery'
        valid_rows = paddle.reshape(valid_rows, [valid_rows.shape[0]])
        equal_flag = paddle.index_select(equal_flag, valid_rows, axis=0)

        # precision@k at every rank k = 1..num_gallery
        acc_sum = paddle.cumsum(equal_flag, axis=1)
        div = paddle.arange(acc_sum.shape[1]).astype("float32") + 1
        precision = paddle.divide(acc_sum, div)

        # AP averages precision@k over the ranks that hold a relevant item.
        precision_mask = paddle.multiply(equal_flag, precision)
        ap = paddle.sum(precision_mask, axis=1) / paddle.sum(equal_flag,
                                                             axis=1)
        metric_dict["mAP"] = paddle.mean(ap).numpy()[0]
        return metric_dict
class mINP(nn.Layer):
    """Mean inverse negative penalty of a retrieval ranking.

    For one query, INP is the number of relevant gallery items divided by
    the 1-based rank of the hardest (last-ranked) relevant item.
    """

    def __init__(self):
        super().__init__()

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Compute mINP over the queries that have a match in the gallery.

        Args:
            similarities_matrix: 2-D tensor, one row per query and one
                column per gallery item; larger means more similar.
            query_img_id: id of each query, compared against the ranked
                gallery ids (presumably shape [num_query, 1] - confirm).
            gallery_img_id: 2-D tensor of gallery ids, transposed and
                broadcast to one row per query below.

        Returns:
            dict: ``{"mINP": value}``.

        Raises:
            AssertionError: if no query id appears in the gallery.
        """
        metric_dict = dict()

        # Rank the gallery for every query, best match first, and look up
        # the id found at each rank.
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=True)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        tmp = paddle.equal(choosen_label, query_img_id)
        tmp = paddle.cast(tmp, 'float64')

        # INP is undefined for a query with no relevant gallery item, so
        # such rows are excluded from the mean.
        num_rel = paddle.sum(tmp, axis=1)
        has_rel = paddle.greater_than(num_rel, paddle.zeros_like(num_rel))
        valid_rows = paddle.nonzero(has_rel.astype("int64"))
        assert valid_rows.shape[0] > 0, \
            'Error: all query identities do not appear in gallery'
        valid_rows = paddle.reshape(valid_rows, [valid_rows.shape[0]])
        tmp = paddle.index_select(tmp, valid_rows, axis=0)

        # Weight each hit by 1 - 1 / (i + 2), which grows with the column
        # index i, so argmax returns the index of the LAST hit in each row.
        div = paddle.arange(tmp.shape[1]).astype("float64") + 2
        minus = paddle.divide(tmp, div)
        auxilary = paddle.subtract(tmp, minus)
        hard_index = paddle.argmax(auxilary, axis=1).astype("float64")

        # hard_index is 0-based; the rank of the hardest match is
        # hard_index + 1. Dividing by the bare index overstates INP and
        # divides by zero when the only match is ranked first.
        all_INP = paddle.divide(paddle.sum(tmp, axis=1), hard_index + 1)
        mINP = paddle.mean(all_INP)
        metric_dict["mINP"] = mINP.numpy()[0]
        return metric_dict
class Recallk(nn.Layer):
def __init__(self, topk=(1, 5)):
super().__init__()
......@@ -72,91 +89,26 @@ class Recallk(nn.Layer):
if isinstance(topk, int):
topk = [topk]
self.topk = topk
self.max_rank = max(self.topk) if max(self.topk) > 50 else 50
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
    """Compute recall@k for every k in ``self.topk``.

    Args:
        similarities_matrix: 2-D tensor, one row per query and one column
            per gallery item; larger means more similar.
        query_img_id: id of each query, compared against the ranked
            gallery ids (presumably shape [num_query, 1] - confirm).
        gallery_img_id: 2-D tensor of gallery ids, transposed and
            broadcast to one row per query below.

    Returns:
        dict: ``{"recall<k>": value}`` for each configured k.

    Raises:
        AssertionError: if no query id appears in the gallery.
    """
    metric_dict = dict()

    # Rank the gallery for every query, best match first, and look up the
    # id found at each rank.
    choosen_indices = paddle.argsort(
        similarities_matrix, axis=1, descending=True)
    gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
    gallery_labels_transpose = paddle.broadcast_to(
        gallery_labels_transpose,
        shape=[choosen_indices.shape[0], gallery_labels_transpose.shape[1]])
    choosen_label = paddle.index_sample(gallery_labels_transpose,
                                        choosen_indices)
    equal_flag = paddle.equal(choosen_label, query_img_id)
    equal_flag = paddle.cast(equal_flag, 'float32')

    # Only queries with at least one relevant gallery item can be
    # recalled; normalise by their count rather than by every query.
    num_rel = paddle.sum(equal_flag, axis=1)
    has_rel = paddle.greater_than(num_rel,
                                  paddle.to_tensor(0.)).astype("float32")
    num_valid_q = paddle.sum(has_rel)
    assert num_valid_q.numpy().sum() > 0, \
        'Error: all query identities do not appear in gallery'

    # mask[i, j] is 1 once query i has met a relevant item within its
    # first j + 1 results, so the column average is the CMC curve.
    acc_sum = paddle.cumsum(equal_flag, axis=1)
    mask = paddle.greater_than(acc_sum,
                               paddle.to_tensor(0.)).astype("float32")
    all_cmc = (paddle.sum(mask, axis=0) / num_valid_q).numpy()

    for k in self.topk:
        # Recall cannot change past the last gallery rank, so a k beyond
        # the gallery size reads the final entry instead of overrunning.
        metric_dict["recall{}".format(k)] = all_cmc[min(k, len(all_cmc)) - 1]
    return metric_dict
# retrieval metrics
class RetriMetric(nn.Layer):
    """Compute the configured retrieval metrics from one get_metrics call.

    Args:
        config (dict): maps any of "Recallk", "mAP", "mINP" to that
            metric's parameters; "Recallk" must provide ``topk``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The CMC curve is requested with max_rank entries, so the cut-off
        # must reach the largest requested k; 50 stays the lower bound.
        self.max_rank = max([50] + self._recall_topk())

    def _recall_topk(self):
        """Return the configured recall cut-offs as a list (empty if none)."""
        if "Recallk" not in self.config:
            return []
        topk = self.config['Recallk']['topk']
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        return list(topk)

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return a dict with the configured recall@k, mAP and mINP values."""
        metric_dict = dict()
        all_cmc, all_AP, all_INP = get_metrics(
            similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
        for k in self._recall_topk():
            # The curve may be shorter than k (for example a small
            # gallery); recall is flat past its end, so read the last
            # entry instead of overrunning.
            metric_dict["recall{}".format(k)] = all_cmc[min(k, len(all_cmc))
                                                        - 1]
        if "mAP" in self.config:
            metric_dict["mAP"] = np.mean(all_AP)
        if "mINP" in self.config:
            metric_dict["mINP"] = np.mean(all_INP)
        return metric_dict
# NOTE(review): the cache is keyed on the argument objects themselves, so
# every cached entry keeps its tensors alive; presumably tensors hash by
# identity - confirm.
@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
                max_rank=50):
    """Compute the CMC curve and per-query AP / INP of a retrieval ranking.

    Queries whose id never appears in the gallery are skipped.

    Args:
        similarities_matrix: 2-D tensor, one row per query and one column
            per gallery item; larger means more similar.
        query_img_id: tensor of query ids, flattened to length
            ``query_img_id.shape[0]``.
        gallery_img_id: tensor of gallery ids, flattened to length
            ``gallery_img_id.shape[0]``.
        max_rank (int): length of the returned CMC curve; lowered to the
            gallery size when the gallery is smaller.

    Returns:
        tuple: ``(all_cmc, all_AP, all_INP)`` - the float32 CMC curve
        averaged over valid queries, and the lists of per-query AP and
        INP values.

    Raises:
        AssertionError: if no query id appears in the gallery.
    """
    num_q, num_g = similarities_matrix.shape
    q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
    g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.format(
            num_g))

    ranking = paddle.argsort(
        similarities_matrix, axis=1, descending=True).numpy()
    # matches[i, j] is 1 when the j-th ranked gallery item has query i's id
    matches = (g_pids[ranking] == q_pids[:, np.newaxis]).astype(np.int32)
    # 1-based rank of every column, as floats, for the precision division
    ranks = np.arange(num_g) + 1.

    all_cmc = []
    all_AP = []
    all_INP = []
    for row in matches:
        if not row.any():
            # this query has no relevant gallery item
            continue
        hits = row.cumsum()

        # INP: relevant items found up to the last hit, over that hit's rank
        last_hit = np.flatnonzero(row)[-1]
        all_INP.append(hits[last_hit] / (last_hit + 1.0))

        # CMC: 1 from the first hit onwards, truncated to max_rank
        all_cmc.append(np.minimum(hits, 1)[:max_rank])

        # AP: precision at each rank, kept only where that rank is a hit
        precision_at_hits = hits / ranks * row
        all_AP.append(precision_at_hits.sum() / row.sum())

    num_valid_q = float(len(all_AP))
    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP
class DistillationTopkAcc(TopkAcc):
def __init__(self, model_key, feature_key=None, topk=(1, 5)):
super().__init__(topk=topk)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册