提交 c6865e25 编写于 作者: H HydrogenSulfate 提交者: Walter

refactor(retrieval): polish retrieval.py

上级 97f99cd8
...@@ -16,108 +16,90 @@ from __future__ import division ...@@ -16,108 +16,90 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import platform import platform
from collections import defaultdict
import numpy as np import numpy as np
import paddle import paddle
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger from ppcls.utils import all_gather, logger
from ppcls.utils import all_gather
def retrieval_eval(engine, epoch_id=0): def retrieval_eval(engine, epoch_id=0):
engine.model.eval() engine.model.eval()
# step1. prepare query and gallery features # step1. prepare query and gallery features
if engine.gallery_query_dataloader is not None: if engine.gallery_query_dataloader is not None:
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature( gallery_feat, gallery_label, gallery_camera = compute_feature(
engine, "gallery_query") engine, "gallery_query")
query_feas, query_label_id, query_camera_id = gallery_feas, gallery_label_id, gallery_camera_id query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
else: else:
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature( gallery_feat, gallery_label, gallery_camera = compute_feature(
engine, "gallery") engine, "gallery")
query_feas, query_label_id, query_camera_id = compute_feature(engine, query_feat, query_label, query_camera = compute_feature(engine,
"query") "query")
# step2. split features into feature blocks for saving memory # step2. split features into feature blocks for saving memory
num_query = len(query_feat)
block_size = engine.config["Global"].get("sim_block_size", 64) block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (len(query_feas) // block_size) sections = [block_size] * (num_query // block_size)
if len(query_feas) % block_size > 0: if num_query % block_size > 0:
sections.append(len(query_feas) % block_size) sections.append(num_query % block_size)
query_feas_blocks = paddle.split(query_feas, sections) query_feat_blocks = paddle.split(query_feat, sections)
query_camera_id_blocks = (paddle.split(query_camera_id, sections) query_label_blocks = paddle.split(query_label, sections)
if query_camera_id is not None else None) query_camera_blocks = paddle.split(
query_label_id_blocks = paddle.split(query_label_id, sections) query_camera, sections) if query_camera is not None else None
metric_key = None metric_key = None
# step3. compute metric # step3. compute metric
if engine.eval_loss_func is None: if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.} metric_dict = {metric_key: 0.0}
else: else:
use_reranking = engine.config["Global"].get("re_ranking", False) use_reranking = engine.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}") logger.info(f"re_ranking={use_reranking}")
metric_dict = {}
if use_reranking: if use_reranking:
for _, metric_func in enumerate(
engine.eval_metric_func.metric_func_list):
if hasattr(metric_func,
"descending") and metric_func.descending is True:
metric_func.descending = False
logger.warning(
f"re_ranking=True, set {type_name(metric_func)}.descending set to False"
)
# compute distance matrix # compute distance matrix
distmat = compute_re_ranking_dist( distmat = compute_re_ranking_dist(
query_feas, gallery_feas, engine.config["Global"].get( query_feat, gallery_feat, engine.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3) "feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance # exclude illegal distance
camera_id_mask = query_camera_id != gallery_camera_id.t() if query_camera is not None:
image_id_mask = query_label_id != gallery_label_id.t() camera_mask = query_camera != gallery_camera.t()
keep_mask = paddle.logical_or(image_id_mask, camera_id_mask) label_mask = query_label != gallery_label.t()
distmat = distmat * keep_mask.astype(query_feas.dtype) keep_mask = label_mask | camera_mask
inf_mat = ( distmat = keep_mask.astype(query_feat.dtype) * distmat + (
paddle.logical_not(keep_mask).astype(query_feas.dtype)) * ( ~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
distmat.max() + 1) else:
distmat = distmat + inf_mat keep_mask = None
# compute metric with all samples
metric_block = engine.eval_metric_func(distmat, query_label_id, metric_dict = engine.eval_metric_func(-distmat, query_label,
gallery_label_id, keep_mask) gallery_label, keep_mask)
for key in metric_block:
metric_dict[key] = metric_block[key]
else: else:
for block_idx, block_fea in enumerate(query_feas_blocks): metric_dict = defaultdict(float)
for block_idx, block_feat in enumerate(query_feat_blocks):
# compute distance matrix
distmat = paddle.matmul( distmat = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) block_feat, gallery_feat, transpose_y=True)
if query_camera_id is not None: # exclude illegal distance
query_camera_id_block = query_camera_id_blocks[block_idx] if query_camera is not None:
camera_id_mask = query_camera_id_block != gallery_camera_id.t( camera_mask = query_camera_blocks[
) block_idx] != gallery_camera.t()
label_mask = query_label_blocks[
query_label_id_block = query_label_id_blocks[block_idx] block_idx] != gallery_label.t()
image_id_mask = query_label_id_block != gallery_label_id.t( keep_mask = label_mask | camera_mask
) distmat = keep_mask.astype(query_feat.dtype) * distmat
keep_mask = paddle.logical_or(image_id_mask,
camera_id_mask)
distmat = distmat * keep_mask.astype("float32")
else: else:
keep_mask = None keep_mask = None
# compute metric by block
metric_block = engine.eval_metric_func( metric_block = engine.eval_metric_func(
distmat, query_label_id_blocks[block_idx], distmat, query_label_blocks[block_idx], gallery_label,
gallery_label_id, keep_mask) keep_mask)
# accumulate metric
for key in metric_block: for key in metric_block:
if key not in metric_dict: metric_dict[key] += metric_block[key] * block_feat.shape[
metric_dict[key] = metric_block[key] * block_fea.shape[ 0] / num_query
0] / len(query_feas)
else:
metric_dict[key] += metric_block[
key] * block_fea.shape[0] / len(query_feas)
metric_info_list = [] metric_info_list = []
for key in metric_dict: for key, value in metric_dict.items():
metric_info_list.append(f"{key}: {metric_dict[key]:.5f}") metric_info_list.append(f"{key}: {value:.5f}")
if metric_key is None: if metric_key is None:
metric_key = key metric_key = key
metric_msg = ", ".join(metric_info_list) metric_msg = ", ".join(metric_info_list)
...@@ -127,9 +109,6 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -127,9 +109,6 @@ def retrieval_eval(engine, epoch_id=0):
def compute_feature(engine, name="gallery"): def compute_feature(engine, name="gallery"):
has_camera_id = False
all_camera_id = None
if name == "gallery": if name == "gallery":
dataloader = engine.gallery_dataloader dataloader = engine.gallery_dataloader
elif name == "query": elif name == "query":
...@@ -137,13 +116,16 @@ def compute_feature(engine, name="gallery"): ...@@ -137,13 +116,16 @@ def compute_feature(engine, name="gallery"):
elif name == "gallery_query": elif name == "gallery_query":
dataloader = engine.gallery_query_dataloader dataloader = engine.gallery_query_dataloader
else: else:
raise RuntimeError("Only support gallery or query dataset") raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}"
)
batch_feas_list = [] all_feat = []
label_id_list = [] all_label = []
camera_id_list = [] all_camera = []
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len( max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
dataloader) dataloader)
has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming for idx, batch in enumerate(dataloader): # load is very time-consuming
if idx >= max_iter: if idx >= max_iter:
break break
...@@ -154,8 +136,8 @@ def compute_feature(engine, name="gallery"): ...@@ -154,8 +136,8 @@ def compute_feature(engine, name="gallery"):
batch = [paddle.to_tensor(x) for x in batch] batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3: if len(batch) >= 3:
has_camera_id = True has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval: if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
...@@ -163,62 +145,61 @@ def compute_feature(engine, name="gallery"): ...@@ -163,62 +145,61 @@ def compute_feature(engine, name="gallery"):
"flatten_contiguous_range", "greater_than" "flatten_contiguous_range", "greater_than"
}, },
level=engine.amp_level): level=engine.amp_level):
out = engine.model(batch[0], batch[1]) out = engine.model(batch[0])
else: else:
out = engine.model(batch[0], batch[1]) out = engine.model(batch[0])
if "Student" in out: if "Student" in out:
out = out["Student"] out = out["Student"]
# get features # get features
if engine.config["Global"].get("retrieval_feature_from", if engine.config["Global"].get("retrieval_feature_from",
"features") == "features": "features") == "features":
# use neck's output as features # use output from neck as feature
batch_feas = out["features"] batch_feat = out["features"]
else: else:
# use backbone's output as features # use output from backbone as feature
batch_feas = out["backbone"] batch_feat = out["backbone"]
# do norm(optinal) # do norm(optional)
if engine.config["Global"].get("feature_normalize", True): if engine.config["Global"].get("feature_normalize", True):
batch_feas_norm = paddle.sqrt( batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, batch_feas_norm)
# do binarize(optinal) # do binarize(optional)
if engine.config["Global"].get("feature_binarize") == "round": if engine.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
elif engine.config["Global"].get("feature_binarize") == "sign": elif engine.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32") batch_feat = paddle.sign(batch_feat).astype("float32")
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
batch_feas_list.append(all_gather(batch_feas)) all_feat.append(all_gather(batch_feat))
label_id_list.append(all_gather(batch[1])) all_label.append(all_gather(batch[1]))
if has_camera_id: if has_camera:
camera_id_list.append(all_gather(batch[2])) all_camera.append(all_gather(batch[2]))
else: else:
batch_feas_list.append(batch_feas) all_feat.append(batch_feat)
label_id_list.append(batch[1]) all_label.append(batch[1])
if has_camera_id: if has_camera:
camera_id_list.append(batch[2]) all_camera.append(batch[2])
if engine.use_dali: if engine.use_dali:
dataloader.reset() dataloader.reset()
all_feas = paddle.concat(batch_feas_list) all_feat = paddle.concat(all_feat)
all_label_id = paddle.concat(label_id_list) all_label = paddle.concat(all_label)
if has_camera_id: if has_camera:
all_camera_id = paddle.concat(camera_id_list) all_camera = paddle.concat(all_camera)
else:
all_camera = None
# discard redundant padding sample(s) at the end # discard redundant padding sample(s) at the end
total_samples = len( total_samples = dataloader.size if engine.use_dali else len(
dataloader.dataset) if not engine.use_dali else dataloader.size dataloader.dataset)
all_feas = all_feas[:total_samples] all_feat = all_feat[:total_samples]
all_label_id = all_label_id[:total_samples] all_label = all_label[:total_samples]
if has_camera_id: if has_camera:
all_camera_id = all_camera_id[:total_samples] all_camera = all_camera[:total_samples]
logger.info(f"Build {name} done, all feat shape: {all_feas.shape}") logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
return all_feas, all_label_id, all_camera_id return all_feat, all_label, all_camera
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray: def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
...@@ -239,8 +220,8 @@ def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray: ...@@ -239,8 +220,8 @@ def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
return forward_k_neigh_index[candidate] return forward_k_neigh_index[candidate]
def compute_re_ranking_dist(query_feas: paddle.Tensor, def compute_re_ranking_dist(query_feat: paddle.Tensor,
gallery_feas: paddle.Tensor, gallery_feat: paddle.Tensor,
feature_normed: bool=True, feature_normed: bool=True,
k1: int=20, k1: int=20,
k2: int=6, k2: int=6,
...@@ -251,8 +232,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -251,8 +232,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
Code refernence: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py Code refernence: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py
Args: Args:
query_feas (paddle.Tensor): Query features with shape of [num_query, feature_dim]. query_feat (paddle.Tensor): Query features with shape of [num_query, feature_dim].
gallery_feas (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim]. gallery_feat (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim].
feature_normed (bool, optional): Whether input features are normalized. feature_normed (bool, optional): Whether input features are normalized.
k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20. k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6. k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
...@@ -261,10 +242,10 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -261,10 +242,10 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
Returns: Returns:
paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery]. paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery].
""" """
num_query = query_feas.shape[0] num_query = query_feat.shape[0]
num_gallery = gallery_feas.shape[0] num_gallery = gallery_feat.shape[0]
num_all = num_query + num_gallery num_all = num_query + num_gallery
feat = paddle.concat([query_feas, gallery_feas], 0) feat = paddle.concat([query_feat, gallery_feat], 0)
logger.info("Using GPU to compute original distance matrix") logger.info("Using GPU to compute original distance matrix")
# use L2 distance # use L2 distance
...@@ -273,8 +254,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -273,8 +254,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
else: else:
original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \ original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \
paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t() paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t()
original_dist = original_dist.addmm( original_dist = original_dist.addmm(feat, feat.t(), -2.0, 1.0)
x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
original_dist = original_dist.numpy() original_dist = original_dist.numpy()
del feat del feat
...@@ -298,7 +278,6 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -298,7 +278,6 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind, p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
q_k_reciprocal_ind) q_k_reciprocal_ind)
p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind) p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
# reweight distance using gaussian kernel # reweight distance using gaussian kernel
weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind]) weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight) V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)
...@@ -318,6 +297,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -318,6 +297,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
for gj in range(num_all): for gj in range(num_all):
invIndex.append(np.nonzero(V[:, gj])[0]) invIndex.append(np.nonzero(V[:, gj])[0])
# compute jaccard distance
jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
for p in range(num_query): for p in range(num_query):
sum_min = np.zeros(shape=[1, num_all], dtype=np.float16) sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
...@@ -328,7 +308,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor, ...@@ -328,7 +308,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj]) sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
jaccard_dist[p] = 1 - sum_min / (2 - sum_min) jaccard_dist[p] = 1 - sum_min / (2 - sum_min)
final_dist = jaccard_dist * (1 - lamb) + original_dist * lamb # fuse jaccard distance with original distance
final_dist = (1 - lamb) * jaccard_dist + lamb * original_dist
del original_dist del original_dist
del V del V
del jaccard_dist del jaccard_dist
......
...@@ -287,10 +287,10 @@ class Recallk(nn.Layer): ...@@ -287,10 +287,10 @@ class Recallk(nn.Layer):
keep_mask): keep_mask):
metric_dict = dict() metric_dict = dict()
#get cmc # get cmc
choosen_indices = paddle.argsort( choosen_indices = paddle.argsort(
similarities_matrix, axis=1, descending=self.descending) similarities_matrix, axis=1, descending=self.descending)
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0]) gallery_labels_transpose = gallery_img_id.t()
gallery_labels_transpose = paddle.broadcast_to( gallery_labels_transpose = paddle.broadcast_to(
gallery_labels_transpose, gallery_labels_transpose,
shape=[ shape=[
...@@ -301,18 +301,14 @@ class Recallk(nn.Layer): ...@@ -301,18 +301,14 @@ class Recallk(nn.Layer):
equal_flag = paddle.equal(choosen_label, query_img_id) equal_flag = paddle.equal(choosen_label, query_img_id)
if keep_mask is not None: if keep_mask is not None:
keep_mask = paddle.index_sample( keep_mask = paddle.index_sample(
keep_mask.astype('float32'), choosen_indices) keep_mask.astype("float32"), choosen_indices)
equal_flag = paddle.logical_and(equal_flag, equal_flag = equal_flag & keep_mask.astype("bool")
keep_mask.astype('bool')) equal_flag = paddle.cast(equal_flag, "float32")
equal_flag = paddle.cast(equal_flag, 'float32')
real_query_num = paddle.sum(equal_flag, axis=1) real_query_num = paddle.sum(equal_flag, axis=1)
real_query_num = paddle.sum( real_query_num = paddle.sum((real_query_num > 0.0).astype("float32"))
paddle.greater_than(real_query_num, paddle.to_tensor(0.)).astype(
"float32"))
acc_sum = paddle.cumsum(equal_flag, axis=1) acc_sum = paddle.cumsum(equal_flag, axis=1)
mask = paddle.greater_than(acc_sum, mask = (acc_sum > 0.0).astype("float32")
paddle.to_tensor(0.)).astype("float32")
all_cmc = (paddle.sum(mask, axis=0) / real_query_num).numpy() all_cmc = (paddle.sum(mask, axis=0) / real_query_num).numpy()
for k in self.topk: for k in self.topk:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册