Commit 97f99cd8 authored by HydrogenSulfate, committed by Walter

refactor(retrieval): polish retrieval.py

Parent 8542967b
@@ -16,137 +16,132 @@ from __future__ import division
from __future__ import print_function
import platform
from typing import Optional
import numpy as np
import paddle
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
from ppcls.utils import all_gather
def retrieval_eval(engine, epoch_id=0):
engine.model.eval()
# step1. build query & gallery
# step1. prepare query and gallery features
if engine.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
engine, name='gallery_query')
query_feas, query_img_id, query_unique_id = gallery_feas, gallery_img_id, gallery_unique_id
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
engine, "gallery_query")
query_feas, query_label_id, query_camera_id = gallery_feas, gallery_label_id, gallery_camera_id
else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
engine, name='gallery')
query_feas, query_img_id, query_unique_id = cal_feature(
engine, name='query')
# step2. split data into blocks so as to save memory
sim_block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_unique_id is not None:
query_unique_id_blocks = paddle.split(
query_unique_id, num_or_sections=sections)
query_img_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
engine, "gallery")
query_feas, query_label_id, query_camera_id = compute_feature(engine,
"query")
# step2. split features into feature blocks for saving memory
block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (len(query_feas) // block_size)
if len(query_feas) % block_size > 0:
sections.append(len(query_feas) % block_size)
query_feas_blocks = paddle.split(query_feas, sections)
query_camera_id_blocks = (paddle.split(query_camera_id, sections)
if query_camera_id is not None else None)
query_label_id_blocks = paddle.split(query_label_id, sections)
metric_key = None
# step3. do evaluation
# step3. compute metric
if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.}
else:
# do evaluation with re-ranking(k-reciprocal)
reranking_flag = engine.config['Global'].get('re_ranking', False)
logger.info(f"re_ranking={reranking_flag}")
metric_dict = dict()
if reranking_flag:
# set the order from small to large
for i in range(len(engine.eval_metric_func.metric_func_list)):
if hasattr(engine.eval_metric_func.metric_func_list[i], 'descending') \
and engine.eval_metric_func.metric_func_list[i].descending is True:
engine.eval_metric_func.metric_func_list[
i].descending = False
use_reranking = engine.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}")
metric_dict = {}
if use_reranking:
for _, metric_func in enumerate(
engine.eval_metric_func.metric_func_list):
if hasattr(metric_func,
"descending") and metric_func.descending is True:
metric_func.descending = False
logger.warning(
f"re_ranking=True,{type_name(engine.eval_metric_func.metric_func_list[i])}.descending has been set to False"
f"re_ranking=True, set {type_name(metric_func)}.descending set to False"
)
# compute distance matrix(The smaller the value, the more similar)
distmat = re_ranking(
query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3)
# compute keep mask
unique_id_mask = (query_unique_id != gallery_unique_id.t())
image_id_mask = (query_img_id != gallery_img_id.t())
keep_mask = paddle.logical_or(image_id_mask, unique_id_mask)
# set inf(1e9) distance to those exist in gallery
distmat = distmat * keep_mask.astype("float32")
inf_mat = (paddle.logical_not(keep_mask).astype("float32")) * 1e20
# compute distance matrix
distmat = compute_re_ranking_dist(
query_feas, gallery_feas, engine.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3)
# exclude invalid pairs (same label id and same camera id)
camera_id_mask = query_camera_id != gallery_camera_id.t()
image_id_mask = query_label_id != gallery_label_id.t()
keep_mask = paddle.logical_or(image_id_mask, camera_id_mask)
distmat = distmat * keep_mask.astype(query_feas.dtype)
inf_mat = (
paddle.logical_not(keep_mask).astype(query_feas.dtype)) * (
distmat.max() + 1)
distmat = distmat + inf_mat
# compute metric
metric_tmp = engine.eval_metric_func(distmat, query_img_id,
gallery_img_id, keep_mask)
for key in metric_tmp:
metric_dict[key] = metric_tmp[key]
metric_block = engine.eval_metric_func(distmat, query_label_id,
gallery_label_id, keep_mask)
for key in metric_block:
metric_dict[key] = metric_block[key]
else:
# do evaluation without re-ranking
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) # [n,m]
if query_unique_id is not None:
query_unique_id_block = query_unique_id_blocks[block_idx]
unique_id_mask = (
query_unique_id_block != gallery_unique_id.t())
query_img_id_block = query_img_id_blocks[block_idx]
image_id_mask = (query_img_id_block != gallery_img_id.t())
for block_idx, block_fea in enumerate(query_feas_blocks):
distmat = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
if query_camera_id is not None:
query_camera_id_block = query_camera_id_blocks[block_idx]
camera_id_mask = query_camera_id_block != gallery_camera_id.t(
)
query_label_id_block = query_label_id_blocks[block_idx]
image_id_mask = query_label_id_block != gallery_label_id.t(
)
keep_mask = paddle.logical_or(image_id_mask,
unique_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
camera_id_mask)
distmat = distmat * keep_mask.astype("float32")
else:
keep_mask = None
metric_tmp = engine.eval_metric_func(
similarity_matrix, query_img_id_blocks[block_idx],
gallery_img_id, keep_mask)
metric_block = engine.eval_metric_func(
distmat, query_label_id_blocks[block_idx],
gallery_label_id, keep_mask)
for key in metric_tmp:
for key in metric_block:
if key not in metric_dict:
metric_dict[key] = metric_tmp[key] * block_fea.shape[
metric_dict[key] = metric_block[key] * block_fea.shape[
0] / len(query_feas)
else:
metric_dict[key] += metric_tmp[key] * block_fea.shape[
0] / len(query_feas)
metric_dict[key] += metric_block[
key] * block_fea.shape[0] / len(query_feas)
metric_info_list = []
for key in metric_dict:
metric_info_list.append(f"{key}: {metric_dict[key]:.5f}")
if metric_key is None:
metric_key = key
metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
metric_msg = ", ".join(metric_info_list)
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")
return metric_dict[metric_key]
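
A minimal, self-contained sketch of the block-wise evaluation idea above, with hypothetical shapes and without the engine/config plumbing. The mask follows the usual re-id convention: a gallery sample is kept only if it differs from the query in label id or camera id.

import paddle

# Hypothetical features and ids: 150 queries, 400 gallery samples, dim 128.
query_feas = paddle.rand([150, 128])
gallery_feas = paddle.rand([400, 128])
query_label_id = paddle.randint(0, 10, [150, 1])
gallery_label_id = paddle.randint(0, 10, [400, 1])
query_camera_id = paddle.randint(0, 6, [150, 1])
gallery_camera_id = paddle.randint(0, 6, [400, 1])

block_size = 64
sections = [block_size] * (len(query_feas) // block_size)
if len(query_feas) % block_size > 0:
    sections.append(len(query_feas) % block_size)

query_feas_blocks = paddle.split(query_feas, sections)
query_label_blocks = paddle.split(query_label_id, sections)
query_camera_blocks = paddle.split(query_camera_id, sections)

for feas, labels, cameras in zip(query_feas_blocks, query_label_blocks,
                                 query_camera_blocks):
    # Only one [block, num_gallery] similarity block lives in memory at a time.
    distmat = paddle.matmul(feas, gallery_feas, transpose_y=True)
    # Keep a pair if its label OR camera differs (broadcast [block,1] vs [1,num_gallery]).
    keep_mask = paddle.logical_or(labels != gallery_label_id.t(),
                                  cameras != gallery_camera_id.t())
    distmat = distmat * keep_mask.astype("float32")
    # ... feed (distmat, labels, gallery_label_id, keep_mask) to the metric ...
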
def cal_feature(engine, name='gallery'):
has_unique_id = False
all_unique_id = None
def compute_feature(engine, name="gallery"):
has_camera_id = False
all_camera_id = None
if name == 'gallery':
if name == "gallery":
dataloader = engine.gallery_dataloader
elif name == 'query':
elif name == "query":
dataloader = engine.query_dataloader
elif name == 'gallery_query':
elif name == "gallery_query":
dataloader = engine.gallery_query_dataloader
else:
raise RuntimeError("Only support gallery or query dataset")
batch_feas_list = []
img_id_list = []
unique_id_list = []
label_id_list = []
camera_id_list = []
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
dataloader)
for idx, batch in enumerate(dataloader): # load is very time-consuming
@@ -160,7 +155,7 @@ def cal_feature(engine, name='gallery'):
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3:
has_unique_id = True
has_camera_id = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast(
@@ -183,158 +178,160 @@ def cal_feature(engine, name='gallery'):
# use backbone's output as features
batch_feas = out["backbone"]
# do norm
# do norm (optional)
if engine.config["Global"].get("feature_normalize", True):
feas_norm = paddle.sqrt(
batch_feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)
batch_feas = paddle.divide(batch_feas, batch_feas_norm)
# do binarize
# do binarize (optional)
if engine.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
if engine.config["Global"].get("feature_binarize") == "sign":
elif engine.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32")
if paddle.distributed.get_world_size() > 1:
batch_feas_gather = []
img_id_gather = []
unique_id_gather = []
paddle.distributed.all_gather(batch_feas_gather, batch_feas)
paddle.distributed.all_gather(img_id_gather, batch[1])
batch_feas_list.append(paddle.concat(batch_feas_gather))
img_id_list.append(paddle.concat(img_id_gather))
if has_unique_id:
paddle.distributed.all_gather(unique_id_gather, batch[2])
unique_id_list.append(paddle.concat(unique_id_gather))
batch_feas_list.append(all_gather(batch_feas))
label_id_list.append(all_gather(batch[1]))
if has_camera_id:
camera_id_list.append(all_gather(batch[2]))
else:
batch_feas_list.append(batch_feas)
img_id_list.append(batch[1])
if has_unique_id:
unique_id_list.append(batch[2])
label_id_list.append(batch[1])
if has_camera_id:
camera_id_list.append(batch[2])
if engine.use_dali:
dataloader.reset()
all_feas = paddle.concat(batch_feas_list)
all_img_id = paddle.concat(img_id_list)
if has_unique_id:
all_unique_id = paddle.concat(unique_id_list)
all_label_id = paddle.concat(label_id_list)
if has_camera_id:
all_camera_id = paddle.concat(camera_id_list)
# just for DistributedBatchSampler issue: repeat sampling
# discard redundant padding sample(s) at the end
total_samples = len(
dataloader.dataset) if not engine.use_dali else dataloader.size
all_feas = all_feas[:total_samples]
all_img_id = all_img_id[:total_samples]
if has_unique_id:
all_unique_id = all_unique_id[:total_samples]
all_label_id = all_label_id[:total_samples]
if has_camera_id:
all_camera_id = all_camera_id[:total_samples]
logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
name, all_feas.shape))
return all_feas, all_img_id, all_unique_id
logger.info(f"Build {name} done, all feat shape: {all_feas.shape}")
return all_feas, all_label_id, all_camera_id
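
The normalize/binarize post-processing in compute_feature amounts to the following stand-alone sketch (hypothetical tensor and config value):

import paddle

batch_feas = paddle.rand([8, 256])  # hypothetical backbone output

# L2-normalize each feature row (the feature_normalize=True branch).
feas_norm = paddle.sqrt(
    paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)

# Optional binarization (the feature_binarize branch).
binarize = "round"  # hypothetical config value: "round", "sign", or None
if binarize == "round":
    # round to {0, 1}, then rescale to {-1.0, +1.0}
    batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
elif binarize == "sign":
    batch_feas = paddle.sign(batch_feas).astype("float32")
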
def re_ranking(query_feas: paddle.Tensor,
gallery_feas: paddle.Tensor,
k1: int=20,
k2: int=6,
lambda_value: int=0.5,
local_distmat: Optional[np.ndarray]=None,
only_local: bool=False) -> paddle.Tensor:
"""re-ranking, most computed with numpy
code heavily based on
https://github.com/michuanhaohao/reid-strong-baseline/blob/3da7e6f03164a92e696cb6da059b1cd771b0346d/utils/reid_metric.py
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
"""Implementation of k-reciprocal nearest neighbors, i.e. R(p, k)
Args:
query_feas (paddle.Tensor): query features, [num_query, num_features]
gallery_feas (paddle.Tensor): gallery features, [num_gallery, num_features]
k1 (int, optional): k1. Defaults to 20.
k2 (int, optional): k2. Defaults to 6.
lambda_value (int, optional): lambda. Defaults to 0.5.
local_distmat (Optional[np.ndarray], optional): local_distmat. Defaults to None.
only_local (bool, optional): only_local. Defaults to False.
rank (np.ndarray): Rank mat with shape of [N, N].
p (int): Probe index.
k (int): Parameter k for k-reciprocal nearest neighbors algorithm.
Returns:
paddle.Tensor: final_dist matrix after re-ranking, [num_query, num_gallery]
np.ndarray: K-reciprocal nearest neighbors of probe p with shape of [M, ].
"""
query_num = query_feas.shape[0]
all_num = query_num + gallery_feas.shape[0]
if only_local:
original_dist = local_distmat
else:
feat = paddle.concat([query_feas, gallery_feas])
logger.info('using GPU to compute original distance')
# use k+1 for excluding probe index itself
forward_k_neigh_index = rank[p, :k + 1]
backward_k_neigh_index = rank[forward_k_neigh_index, :k + 1]
candidate = np.where(backward_k_neigh_index == p)[0]
return forward_k_neigh_index[candidate]
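
A toy sketch of R(p, k) on a hand-made 4x4 distance matrix (hypothetical data; the helper repeats the logic of k_reciprocal_neighbor so the snippet is self-contained). A sample q belongs to R(p, k) only when p also appears among q's own top-k neighbors.

import numpy as np

# Hypothetical symmetric distance matrix: row i holds distances from sample i.
dist = np.array([[0.0, 0.1, 0.8, 0.9],
                 [0.1, 0.0, 0.7, 0.2],
                 [0.8, 0.7, 0.0, 0.3],
                 [0.9, 0.2, 0.3, 0.0]])
rank = np.argsort(dist, axis=1).astype(np.int32)

def k_reciprocal_neighbor(rank, p, k):
    # same logic as the function above: keep forward neighbors q of p
    # whose own top-(k+1) list also contains p
    forward = rank[p, :k + 1]
    backward = rank[forward, :k + 1]
    candidate = np.where(backward == p)[0]
    return forward[candidate]

# Sample 2 is in sample 0's top-3 list, but 0 is not in sample 2's,
# so only samples 0 (itself) and 1 survive.
print(k_reciprocal_neighbor(rank, 0, 2))  # -> [0 1]
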
def compute_re_ranking_dist(query_feas: paddle.Tensor,
gallery_feas: paddle.Tensor,
feature_normed: bool=True,
k1: int=20,
k2: int=6,
lamb: float=0.5) -> paddle.Tensor:
"""
Re-ranking Person Re-identification with k-reciprocal Encoding
Reference: https://arxiv.org/abs/1701.08398
Code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py
# L2 distance
distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \
paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t()
distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
Args:
query_feas (paddle.Tensor): Query features with shape of [num_query, feature_dim].
gallery_feas (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim].
feature_normed (bool, optional): Whether input features are normalized.
k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
lamb (float, optional): Penalty factor. Defaults to 0.5.
original_dist = distmat.cpu().numpy()
del feat
if local_distmat is not None:
original_dist = original_dist + local_distmat
Returns:
paddle.Tensor: Final distance matrix, (1 - lamb) * D_Jaccard + lamb * D_original, with shape of [num_query, num_gallery].
"""
num_query = query_feas.shape[0]
num_gallery = gallery_feas.shape[0]
num_all = num_query + num_gallery
feat = paddle.concat([query_feas, gallery_feas], 0)
logger.info("Using GPU to compute original distance matrix")
# use L2 distance
if feature_normed:
original_dist = 2 - 2 * paddle.matmul(feat, feat, transpose_y=True)
else:
original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \
paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t()
original_dist = original_dist.addmm(
x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
original_dist = original_dist.numpy()
del feat
gallery_num = original_dist.shape[0]
original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
V = np.zeros_like(original_dist).astype(np.float16)
initial_rank = np.argsort(original_dist).astype(np.int32)
logger.info('starting re_ranking')
for i in range(all_num):
# k-reciprocal neighbors
forward_k_neigh_index = initial_rank[i, :k1 + 1]
backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
fi = np.where(backward_k_neigh_index == i)[0]
k_reciprocal_index = forward_k_neigh_index[fi]
k_reciprocal_expansion_index = k_reciprocal_index
for j in range(len(k_reciprocal_index)):
candidate = k_reciprocal_index[j]
candidate_forward_k_neigh_index = initial_rank[candidate, :int(
np.around(k1 / 2)) + 1]
candidate_backward_k_neigh_index = initial_rank[
candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1]
fi_candidate = np.where(
candidate_backward_k_neigh_index == candidate)[0]
candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
fi_candidate]
if len(
np.intersect1d(candidate_k_reciprocal_index,
k_reciprocal_index)) > 2 / 3 * len(
candidate_k_reciprocal_index):
k_reciprocal_expansion_index = np.append(
k_reciprocal_expansion_index, candidate_k_reciprocal_index)
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
original_dist = original_dist[:query_num, ]
if k2 != 1:
initial_rank = np.argpartition(original_dist, range(1, k1 + 1))
logger.info("Start re-ranking...")
for p in range(num_all):
# compute R(p,k1)
p_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, p, k1)
# compute R*(p,k1)=R(p,k1)∪R(q,k1/2)
# s.t. |R(p,k1)∩R(q,k1/2)|>=2/3|R(q,k1/2)|, ∀q∈R(p,k1)
p_k_reciprocal_exp_ind = p_k_reciprocal_ind
for _, q in enumerate(p_k_reciprocal_ind):
q_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, q,
int(np.around(k1 / 2)))
if len(np.intersect1d(p_k_reciprocal_ind, q_k_reciprocal_ind)
) > 2 / 3 * len(q_k_reciprocal_ind):
p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
q_k_reciprocal_ind)
p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
# reweight distance using gaussian kernel
weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)
# local query expansion
original_dist = original_dist[:num_query, ]
if k2 > 1:
V_qe = np.zeros_like(V, dtype=np.float16)
for i in range(all_num):
V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
for p in range(num_all):
V_qe[p, :] = np.mean(V[initial_rank[p, :k2], :], axis=0)
V = V_qe
del V_qe
del initial_rank
# cache k-reciprocal sets which contain gj
invIndex = []
for i in range(gallery_num):
invIndex.append(np.where(V[:, i] != 0)[0])
for gj in range(num_all):
invIndex.append(np.nonzero(V[:, gj])[0])
jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
for i in range(query_num):
temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
indNonZero = np.where(V[i, :] != 0)[0]
indImages = [invIndex[ind] for ind in indNonZero]
for j in range(len(indNonZero)):
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
final_dist = jaccard_dist * (1 - lambda_value
) + original_dist * lambda_value
for p in range(num_query):
sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
gj_ind = np.nonzero(V[p, :])[0]
gj_ind_inv = [invIndex[gj] for gj in gj_ind]
for j, gj in enumerate(gj_ind):
gi = gj_ind_inv[j]
sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
jaccard_dist[p] = 1 - sum_min / (2 - sum_min)
final_dist = jaccard_dist * (1 - lamb) + original_dist * lamb
del original_dist
del V
del jaccard_dist
final_dist = final_dist[:query_num, query_num:]
final_dist = final_dist[:num_query, num_query:]
final_dist = paddle.to_tensor(final_dist)
return final_dist
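
A sketch of how the refactored function might be called, assuming compute_re_ranking_dist above is in scope. Note that the output is a distance matrix (smaller means more similar), unlike the plain inner-product path, which is why retrieval_eval flips descending on the metric functions when re_ranking is enabled.

import paddle

# Hypothetical L2-normalized query/gallery features.
query_feas = paddle.nn.functional.normalize(paddle.rand([32, 128]), axis=1)
gallery_feas = paddle.nn.functional.normalize(paddle.rand([200, 128]), axis=1)

# Same hyper-parameters as the call in retrieval_eval: k1=20, k2=6, lamb=0.3.
distmat = compute_re_ranking_dist(
    query_feas, gallery_feas, feature_normed=True, k1=20, k2=6, lamb=0.3)
print(distmat.shape)  # [32, 200]
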
@@ -13,15 +13,16 @@
# limitations under the License.
from . import logger
from . import metrics
from . import misc
from . import model_zoo
from . import metrics
from .save_load import init_model, save_model
from .config import get_config
from .misc import AverageMeter
from .metrics import multi_hot_encode
from .metrics import hamming_distance
from .dist_utils import all_gather
from .metrics import accuracy_score
from .metrics import precision_recall_fscore
from .metrics import hamming_distance
from .metrics import mean_average_precision
from .metrics import multi_hot_encode
from .metrics import precision_recall_fscore
from .misc import AverageMeter
from .save_load import init_model, save_model
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import paddle
def all_gather(tensor: paddle.Tensor, concat: bool=True,
axis: int=0) -> Union[paddle.Tensor, List[paddle.Tensor]]:
"""Gather tensor from all devices, concatenate them along given axis if specified.
Args:
tensor (paddle.Tensor): Tensor to be gathered from all GPUs.
concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True.
axis (int, optional): Axis to concatenate along. Defaults to 0.
Returns:
Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensors
"""
result = []
paddle.distributed.all_gather(result, tensor)
if concat:
return paddle.concat(result, axis)
return result
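
A sketch of the intended usage, mirroring how compute_feature calls the helper (hypothetical tensor shapes; multi-card runs require launching with paddle.distributed.launch so that world_size > 1):

import paddle
import paddle.distributed as dist

from ppcls.utils import all_gather

# e.g. `python -m paddle.distributed.launch --gpus "0,1" demo.py`;
# on a single card the gather is simply skipped.
if dist.get_world_size() > 1:
    dist.init_parallel_env()

local_feas = paddle.rand([16, 128])  # this rank's batch of features

if dist.get_world_size() > 1:
    # Returns one tensor of shape [16 * world_size, 128] (concat along axis 0),
    # or a list of per-rank tensors when concat=False.
    global_feas = all_gather(local_feas)
else:
    global_feas = local_feas
print(global_feas.shape)
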