diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py
index 4ec7355bf035c8586587687ca37a58e45d1a3c80..40e75346f065678517ca3e7592f11e12163ba7c7 100644
--- a/ppcls/engine/evaluation/retrieval.py
+++ b/ppcls/engine/evaluation/retrieval.py
@@ -16,137 +16,132 @@
 from __future__ import division
 from __future__ import print_function
 
 import platform
-from typing import Optional
 
 import numpy as np
 import paddle
 
 from ppcls.engine.train.utils import type_name
 from ppcls.utils import logger
+from ppcls.utils import all_gather
 
 
 def retrieval_eval(engine, epoch_id=0):
     engine.model.eval()
-    # step1. build query & gallery
+    # step1. prepare query and gallery features
     if engine.gallery_query_dataloader is not None:
-        gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
-            engine, name='gallery_query')
-        query_feas, query_img_id, query_unique_id = gallery_feas, gallery_img_id, gallery_unique_id
+        gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
+            engine, "gallery_query")
+        query_feas, query_label_id, query_camera_id = gallery_feas, gallery_label_id, gallery_camera_id
     else:
-        gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
-            engine, name='gallery')
-        query_feas, query_img_id, query_unique_id = cal_feature(
-            engine, name='query')
-
-    # step2. split data into blocks so as to save memory
-    sim_block_size = engine.config["Global"].get("sim_block_size", 64)
-    sections = [sim_block_size] * (len(query_feas) // sim_block_size)
-    if len(query_feas) % sim_block_size:
-        sections.append(len(query_feas) % sim_block_size)
-
-    fea_blocks = paddle.split(query_feas, num_or_sections=sections)
-    if query_unique_id is not None:
-        query_unique_id_blocks = paddle.split(
-            query_unique_id, num_or_sections=sections)
-    query_img_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
+        gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
+            engine, "gallery")
+        query_feas, query_label_id, query_camera_id = compute_feature(engine,
+                                                                      "query")
+
+    # step2. split features into blocks to save memory
+    block_size = engine.config["Global"].get("sim_block_size", 64)
+    sections = [block_size] * (len(query_feas) // block_size)
+    if len(query_feas) % block_size > 0:
+        sections.append(len(query_feas) % block_size)
+
+    query_feas_blocks = paddle.split(query_feas, sections)
+    query_camera_id_blocks = (paddle.split(query_camera_id, sections)
+                              if query_camera_id is not None else None)
+    query_label_id_blocks = paddle.split(query_label_id, sections)
     metric_key = None
 
-    # step3. do evaluation
+    # step3. compute metric
     if engine.eval_loss_func is None:
         metric_dict = {metric_key: 0.}
     else:
-        # do evaluation with re-ranking(k-reciprocal)
-        reranking_flag = engine.config['Global'].get('re_ranking', False)
-        logger.info(f"re_ranking={reranking_flag}")
-        metric_dict = dict()
-        if reranking_flag:
-            # set the order from small to large
-            for i in range(len(engine.eval_metric_func.metric_func_list)):
-                if hasattr(engine.eval_metric_func.metric_func_list[i], 'descending') \
-                        and engine.eval_metric_func.metric_func_list[i].descending is True:
-                    engine.eval_metric_func.metric_func_list[
-                        i].descending = False
+        use_reranking = engine.config["Global"].get("re_ranking", False)
+        logger.info(f"re_ranking={use_reranking}")
+        metric_dict = {}
+        if use_reranking:
+            for _, metric_func in enumerate(
+                    engine.eval_metric_func.metric_func_list):
+                if hasattr(metric_func,
+                           "descending") and metric_func.descending is True:
+                    metric_func.descending = False
                     logger.warning(
-                        f"re_ranking=True,{type_name(engine.eval_metric_func.metric_func_list[i])}.descending has been set to False"
+                        f"re_ranking=True, {type_name(metric_func)}.descending has been set to False"
                     )
-
-            # compute distance matrix(The smaller the value, the more similar)
-            distmat = re_ranking(
-                query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3)
-
-            # compute keep mask
-            unique_id_mask = (query_unique_id != gallery_unique_id.t())
-            image_id_mask = (query_img_id != gallery_img_id.t())
-            keep_mask = paddle.logical_or(image_id_mask, unique_id_mask)
-
-            # set inf(1e9) distance to those exist in gallery
-            distmat = distmat * keep_mask.astype("float32")
-            inf_mat = (paddle.logical_not(keep_mask).astype("float32")) * 1e20
+            # compute distance matrix
+            distmat = compute_re_ranking_dist(
+                query_feas, gallery_feas, engine.config["Global"].get(
+                    "feature_normalize", True), 20, 6, 0.3)
+
+            # exclude illegal distance
+            camera_id_mask = query_camera_id != gallery_camera_id.t()
+            image_id_mask = query_label_id != gallery_label_id.t()
+            keep_mask = paddle.logical_or(image_id_mask, camera_id_mask)
+            distmat = distmat * keep_mask.astype(query_feas.dtype)
+            inf_mat = (
+                paddle.logical_not(keep_mask).astype(query_feas.dtype)) * (
+                    distmat.max() + 1)
             distmat = distmat + inf_mat
 
-            # compute metric
-            metric_tmp = engine.eval_metric_func(distmat, query_img_id,
-                                                 gallery_img_id, keep_mask)
-            for key in metric_tmp:
-                metric_dict[key] = metric_tmp[key]
+            metric_block = engine.eval_metric_func(distmat, query_label_id,
+                                                   gallery_label_id, keep_mask)
+            for key in metric_block:
+                metric_dict[key] = metric_block[key]
         else:
-            # do evaluation without re-ranking
-            for block_idx, block_fea in enumerate(fea_blocks):
-                similarity_matrix = paddle.matmul(
-                    block_fea, gallery_feas, transpose_y=True)  # [n,m]
-                if query_unique_id is not None:
-                    query_unique_id_block = query_unique_id_blocks[block_idx]
-                    unique_id_mask = (
-                        query_unique_id_block != gallery_unique_id.t())
-
-                    query_img_id_block = query_img_id_blocks[block_idx]
-                    image_id_mask = (query_img_id_block != gallery_img_id.t())
+            for block_idx, block_fea in enumerate(query_feas_blocks):
+                distmat = paddle.matmul(
+                    block_fea, gallery_feas, transpose_y=True)
+                if query_camera_id is not None:
+                    query_camera_id_block = query_camera_id_blocks[block_idx]
+                    camera_id_mask = query_camera_id_block != gallery_camera_id.t(
+                    )
+
+                    query_label_id_block = query_label_id_blocks[block_idx]
+                    image_id_mask = query_label_id_block != gallery_label_id.t(
+                    )
                     keep_mask = paddle.logical_or(image_id_mask,
-                                                  unique_id_mask)
-                    similarity_matrix = similarity_matrix * keep_mask.astype(
-                        "float32")
+                                                  camera_id_mask)
+                    distmat = distmat * keep_mask.astype("float32")
                 else:
                     keep_mask = None
 
-                metric_tmp = engine.eval_metric_func(
-                    similarity_matrix, query_img_id_blocks[block_idx],
-                    gallery_img_id, keep_mask)
+                metric_block = engine.eval_metric_func(
+                    distmat, query_label_id_blocks[block_idx],
+                    gallery_label_id, keep_mask)
 
-                for key in metric_tmp:
+                for key in metric_block:
                     if key not in metric_dict:
-                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                        metric_dict[key] = metric_block[key] * block_fea.shape[
                             0] / len(query_feas)
                     else:
-                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
-                            0] / len(query_feas)
+                        metric_dict[key] += metric_block[
+                            key] * block_fea.shape[0] / len(query_feas)
 
     metric_info_list = []
     for key in metric_dict:
+        metric_info_list.append(f"{key}: {metric_dict[key]:.5f}")
         if metric_key is None:
             metric_key = key
-        metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
     metric_msg = ", ".join(metric_info_list)
-    logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
+    logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")
 
     return metric_dict[metric_key]
 
 
-def cal_feature(engine, name='gallery'):
-    has_unique_id = False
-    all_unique_id = None
+def compute_feature(engine, name="gallery"):
+    has_camera_id = False
+    all_camera_id = None
 
-    if name == 'gallery':
+    if name == "gallery":
         dataloader = engine.gallery_dataloader
-    elif name == 'query':
+    elif name == "query":
         dataloader = engine.query_dataloader
-    elif name == 'gallery_query':
+    elif name == "gallery_query":
         dataloader = engine.gallery_query_dataloader
     else:
         raise RuntimeError("Only support gallery or query dataset")
 
     batch_feas_list = []
-    img_id_list = []
-    unique_id_list = []
+    label_id_list = []
+    camera_id_list = []
     max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
         dataloader)
     for idx, batch in enumerate(dataloader):  # load is very time-consuming
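
Note: the step2 arithmetic in `retrieval_eval` above builds `sections` as full blocks plus one remainder block. A minimal standalone sketch of that logic, outside the patch (sizes are hypothetical, chosen only for illustration):

```python
import paddle

num_query, block_size = 150, 64  # hypothetical sizes

# full blocks plus one remainder block, mirroring step2 above
sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0:
    sections.append(num_query % block_size)
# sections == [64, 64, 22]

query_feas = paddle.rand([num_query, 128])
blocks = paddle.split(query_feas, sections)  # tensors with 64, 64 and 22 rows
```

Each block's metric is then accumulated with weight `block_fea.shape[0] / len(query_feas)`, so any metric that is a mean over queries recombines exactly to its full-set value.
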
@@ -160,7 +155,7 @@ def cal_feature(engine, name='gallery'):
         batch = [paddle.to_tensor(x) for x in batch]
         batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         if len(batch) == 3:
-            has_unique_id = True
+            has_camera_id = True
             batch[2] = batch[2].reshape([-1, 1]).astype("int64")
         if engine.amp and engine.amp_eval:
             with paddle.amp.auto_cast(
@@ -183,158 +178,160 @@ def cal_feature(engine, name='gallery'):
 
             # use backbone's output as features
             batch_feas = out["backbone"]
 
-        # do norm
+        # do norm (optional)
         if engine.config["Global"].get("feature_normalize", True):
-            feas_norm = paddle.sqrt(
+            batch_feas_norm = paddle.sqrt(
                 paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
-            batch_feas = paddle.divide(batch_feas, feas_norm)
+            batch_feas = paddle.divide(batch_feas, batch_feas_norm)
 
-        # do binarize
+        # do binarize (optional)
         if engine.config["Global"].get("feature_binarize") == "round":
             batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
-
-        if engine.config["Global"].get("feature_binarize") == "sign":
+        elif engine.config["Global"].get("feature_binarize") == "sign":
             batch_feas = paddle.sign(batch_feas).astype("float32")
 
         if paddle.distributed.get_world_size() > 1:
-            batch_feas_gather = []
-            img_id_gather = []
-            unique_id_gather = []
-            paddle.distributed.all_gather(batch_feas_gather, batch_feas)
-            paddle.distributed.all_gather(img_id_gather, batch[1])
-            batch_feas_list.append(paddle.concat(batch_feas_gather))
-            img_id_list.append(paddle.concat(img_id_gather))
-            if has_unique_id:
-                paddle.distributed.all_gather(unique_id_gather, batch[2])
-                unique_id_list.append(paddle.concat(unique_id_gather))
+            batch_feas_list.append(all_gather(batch_feas))
+            label_id_list.append(all_gather(batch[1]))
+            if has_camera_id:
+                camera_id_list.append(all_gather(batch[2]))
         else:
             batch_feas_list.append(batch_feas)
-            img_id_list.append(batch[1])
-            if has_unique_id:
-                unique_id_list.append(batch[2])
+            label_id_list.append(batch[1])
+            if has_camera_id:
+                camera_id_list.append(batch[2])
 
     if engine.use_dali:
         dataloader.reset()
 
     all_feas = paddle.concat(batch_feas_list)
-    all_img_id = paddle.concat(img_id_list)
-    if has_unique_id:
-        all_unique_id = paddle.concat(unique_id_list)
+    all_label_id = paddle.concat(label_id_list)
+    if has_camera_id:
+        all_camera_id = paddle.concat(camera_id_list)
 
-    # just for DistributedBatchSampler issue: repeat sampling
+    # discard redundant padding sample(s) at the end
    total_samples = len(
         dataloader.dataset) if not engine.use_dali else dataloader.size
     all_feas = all_feas[:total_samples]
-    all_img_id = all_img_id[:total_samples]
-    if has_unique_id:
-        all_unique_id = all_unique_id[:total_samples]
+    all_label_id = all_label_id[:total_samples]
+    if has_camera_id:
+        all_camera_id = all_camera_id[:total_samples]
 
-    logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
-        name, all_feas.shape))
-    return all_feas, all_img_id, all_unique_id
+    logger.info(f"Build {name} done, all feat shape: {all_feas.shape}")
+    return all_feas, all_label_id, all_camera_id
 
 
-def re_ranking(query_feas: paddle.Tensor,
-               gallery_feas: paddle.Tensor,
-               k1: int=20,
-               k2: int=6,
-               lambda_value: int=0.5,
-               local_distmat: Optional[np.ndarray]=None,
-               only_local: bool=False) -> paddle.Tensor:
-    """re-ranking, most computed with numpy
-
-    code heavily based on
-    https://github.com/michuanhaohao/reid-strong-baseline/blob/3da7e6f03164a92e696cb6da059b1cd771b0346d/utils/reid_metric.py
+def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
+    """Implementation of k-reciprocal nearest neighbors, i.e. R(p, k)
 
     Args:
-        query_feas (paddle.Tensor): query features, [num_query, num_features]
-        gallery_feas (paddle.Tensor): gallery features, [num_gallery, num_features]
-        k1 (int, optional): k1. Defaults to 20.
-        k2 (int, optional): k2. Defaults to 6.
-        lambda_value (int, optional): lambda. Defaults to 0.5.
-        local_distmat (Optional[np.ndarray], optional): local_distmat. Defaults to None.
-        only_local (bool, optional): only_local. Defaults to False.
+        rank (np.ndarray): Rank matrix with shape of [N, N].
+        p (int): Probe index.
+        k (int): Parameter k for k-reciprocal nearest neighbors algorithm.
 
     Returns:
-        paddle.Tensor: final_dist matrix after re-ranking, [num_query, num_gallery]
+        np.ndarray: K-reciprocal nearest neighbors of probe p with shape of [M, ].
""" - query_num = query_feas.shape[0] - all_num = query_num + gallery_feas.shape[0] - if only_local: - original_dist = local_distmat - else: - feat = paddle.concat([query_feas, gallery_feas]) - logger.info('using GPU to compute original distance') + # use k+1 for excluding probe index itself + forward_k_neigh_index = rank[p, :k + 1] + backward_k_neigh_index = rank[forward_k_neigh_index, :k + 1] + candidate = np.where(backward_k_neigh_index == p)[0] + return forward_k_neigh_index[candidate] + + +def compute_re_ranking_dist(query_feas: paddle.Tensor, + gallery_feas: paddle.Tensor, + feature_normed: bool=True, + k1: int=20, + k2: int=6, + lamb: float=0.5) -> paddle.Tensor: + """ + Re-ranking Person Re-identification with k-reciprocal Encoding + Reference: https://arxiv.org/abs/1701.08398 + Code refernence: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py - # L2 distance - distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \ - paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t() - distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0) + Args: + query_feas (paddle.Tensor): Query features with shape of [num_query, feature_dim]. + gallery_feas (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim]. + feature_normed (bool, optional): Whether input features are normalized. + k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20. + k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6. + lamb (float, optional): Penalty factor. Defaults to 0.5. - original_dist = distmat.cpu().numpy() - del feat - if local_distmat is not None: - original_dist = original_dist + local_distmat + Returns: + paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery]. 
+ """ + num_query = query_feas.shape[0] + num_gallery = gallery_feas.shape[0] + num_all = num_query + num_gallery + feat = paddle.concat([query_feas, gallery_feas], 0) + logger.info("Using GPU to compute original distance matrix") + + # use L2 distance + if feature_normed: + original_dist = 2 - 2 * paddle.matmul(feat, feat, transpose_y=True) + else: + original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \ + paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t() + original_dist = original_dist.addmm( + x=feat, y=feat.t(), alpha=-2.0, beta=1.0) + original_dist = original_dist.numpy() + del feat - gallery_num = original_dist.shape[0] original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) V = np.zeros_like(original_dist).astype(np.float16) - initial_rank = np.argsort(original_dist).astype(np.int32) - logger.info('starting re_ranking') - for i in range(all_num): - # k-reciprocal neighbors - forward_k_neigh_index = initial_rank[i, :k1 + 1] - backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] - fi = np.where(backward_k_neigh_index == i)[0] - k_reciprocal_index = forward_k_neigh_index[fi] - k_reciprocal_expansion_index = k_reciprocal_index - for j in range(len(k_reciprocal_index)): - candidate = k_reciprocal_index[j] - candidate_forward_k_neigh_index = initial_rank[candidate, :int( - np.around(k1 / 2)) + 1] - candidate_backward_k_neigh_index = initial_rank[ - candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1] - fi_candidate = np.where( - candidate_backward_k_neigh_index == candidate)[0] - candidate_k_reciprocal_index = candidate_forward_k_neigh_index[ - fi_candidate] - if len( - np.intersect1d(candidate_k_reciprocal_index, - k_reciprocal_index)) > 2 / 3 * len( - candidate_k_reciprocal_index): - k_reciprocal_expansion_index = np.append( - k_reciprocal_expansion_index, candidate_k_reciprocal_index) - - k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) - weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) - V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) - original_dist = original_dist[:query_num, ] - if k2 != 1: + initial_rank = np.argpartition(original_dist, range(1, k1 + 1)) + logger.info("Start re-ranking...") + + for p in range(num_all): + # compute R(p,k1) + p_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, p, k1) + + # compute R*(p,k1)=R(p,k1)∪R(q,k1/2) + # s.t. 
+        # s.t. |R(p,k1)∩R(q,k1/2)| >= 2/3|R(q,k1/2)|, ∀q∈R(p,k1)
+        p_k_reciprocal_exp_ind = p_k_reciprocal_ind
+        for _, q in enumerate(p_k_reciprocal_ind):
+            q_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, q,
+                                                       int(np.around(k1 / 2)))
+            if len(np.intersect1d(p_k_reciprocal_ind, q_k_reciprocal_ind)
+                   ) > 2 / 3 * len(q_k_reciprocal_ind):
+                p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
+                                                   q_k_reciprocal_ind)
+        p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
+
+        # reweight distance using gaussian kernel
+        weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
+        V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)
+
+    # local query expansion
+    original_dist = original_dist[:num_query, ]
+    if k2 > 1:
         V_qe = np.zeros_like(V, dtype=np.float16)
-        for i in range(all_num):
-            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
+        for p in range(num_all):
+            V_qe[p, :] = np.mean(V[initial_rank[p, :k2], :], axis=0)
         V = V_qe
         del V_qe
     del initial_rank
+
+    # cache k-reciprocal sets which contain gj
     invIndex = []
-    for i in range(gallery_num):
-        invIndex.append(np.where(V[:, i] != 0)[0])
+    for gj in range(num_all):
+        invIndex.append(np.nonzero(V[:, gj])[0])
 
     jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
-    for i in range(query_num):
-        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
-        indNonZero = np.where(V[i, :] != 0)[0]
-        indImages = [invIndex[ind] for ind in indNonZero]
-        for j in range(len(indNonZero)):
-            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
-                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
-        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
-
-    final_dist = jaccard_dist * (1 - lambda_value
-                                 ) + original_dist * lambda_value
+    for p in range(num_query):
+        sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
+        gj_ind = np.nonzero(V[p, :])[0]
+        gj_ind_inv = [invIndex[gj] for gj in gj_ind]
+        for j, gj in enumerate(gj_ind):
+            gi = gj_ind_inv[j]
+            sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
+        jaccard_dist[p] = 1 - sum_min / (2 - sum_min)
+
+    final_dist = jaccard_dist * (1 - lamb) + original_dist * lamb
     del original_dist
     del V
     del jaccard_dist
-    final_dist = final_dist[:query_num, query_num:]
+    final_dist = final_dist[:num_query, num_query:]
     final_dist = paddle.to_tensor(final_dist)
     return final_dist
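
A usage sketch of the refactored entry point (shapes and values are illustrative only; `feature_normed=True` assumes the inputs are already L2-normalized, matching the `feature_normalize` default passed by `retrieval_eval`):

```python
import paddle
import paddle.nn.functional as F
from ppcls.engine.evaluation.retrieval import compute_re_ranking_dist

query_feas = F.normalize(paddle.rand([8, 32]), axis=1)     # [num_query, dim]
gallery_feas = F.normalize(paddle.rand([16, 32]), axis=1)  # [num_gallery, dim]

# smaller value means more similar; result shape is [8, 16]
distmat = compute_re_ranking_dist(
    query_feas, gallery_feas, feature_normed=True, k1=20, k2=6, lamb=0.3)
```
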
diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py
index 632cc78824d51d5adae9315fda8fccde50eda73a..f9307ffd27a9ab0c1f4bab04ca6b21b9f21098e4 100644
--- a/ppcls/utils/__init__.py
+++ b/ppcls/utils/__init__.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 
 from . import logger
+from . import metrics
 from . import misc
 from . import model_zoo
-from . import metrics
 
-from .save_load import init_model, save_model
 from .config import get_config
-from .misc import AverageMeter
-from .metrics import multi_hot_encode
-from .metrics import hamming_distance
+from .dist_utils import all_gather
 from .metrics import accuracy_score
-from .metrics import precision_recall_fscore
+from .metrics import hamming_distance
 from .metrics import mean_average_precision
+from .metrics import multi_hot_encode
+from .metrics import precision_recall_fscore
+from .misc import AverageMeter
+from .save_load import init_model, save_model
diff --git a/ppcls/utils/dist_utils.py b/ppcls/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7d889d6585ff5a59c28806cb18fbe767d03316
--- /dev/null
+++ b/ppcls/utils/dist_utils.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import paddle
+
+
+def all_gather(tensor: paddle.Tensor, concat: bool=True,
+               axis: int=0) -> Union[paddle.Tensor, List[paddle.Tensor]]:
+    """Gather a tensor from all devices, optionally concatenating the results along the given axis.
+
+    Args:
+        tensor (paddle.Tensor): Tensor to be gathered from all GPUs.
+        concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True.
+        axis (int, optional): Axis to concatenate along. Defaults to 0.
+
+    Returns:
+        Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensor(s).
+    """
+    result = []
+    paddle.distributed.all_gather(result, tensor)
+    if concat:
+        return paddle.concat(result, axis)
+    return result
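
Finally, a minimal usage sketch of the new `all_gather` helper (must run under a distributed launch, e.g. `python -m paddle.distributed.launch --gpus 0,1 demo.py`; the script name is hypothetical):

```python
import paddle
import paddle.distributed as dist
from ppcls.utils import all_gather

dist.init_parallel_env()

# each rank contributes one row; the gathered tensor has world_size rows
local = paddle.full([1, 4], float(dist.get_rank()))
gathered = all_gather(local)             # Tensor with shape [world_size, 4]
parts = all_gather(local, concat=False)  # list of world_size Tensors
```
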