diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 6fd0559f13042bb8de546ec8b42d4d08542a6bf1..1b99237f4c23806c58787291916b69c9856c30ec 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -20,11 +20,10 @@ from collections import defaultdict import numpy as np import paddle +import scipy from ppcls.utils import all_gather, logger -# from is_sorted import is_sorted - def retrieval_eval(engine, epoch_id=0): engine.model.eval() @@ -38,12 +37,7 @@ def retrieval_eval(engine, epoch_id=0): engine, "gallery") query_feat, query_label, query_camera = compute_feature(engine, "query") - # gallery_feat = gallery_feat[:50] - # gallery_label = gallery_label[:50] - # gallery_camera = gallery_camera[:50] - # query_feat = query_feat[:20] - # query_label = query_label[:20] - # query_camera = query_camera[:20] + # step2. split features into feature blocks for saving memory num_query = len(query_feat) block_size = engine.config["Global"].get("sim_block_size", 64) @@ -255,8 +249,6 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, num_all = num_query + num_gallery feat = paddle.concat([query_feat, gallery_feat], 0) logger.info("Using GPU to compute original distance matrix") - import time - t = time.perf_counter() # use L2 distance if feature_normed: original_dist = 2 - 2 * paddle.matmul(feat, feat, transpose_y=True) @@ -264,24 +256,12 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \ paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t() original_dist = original_dist.addmm(feat, feat.t(), -2.0, 1.0) - print(f"t1.cost = {time.perf_counter() - t}") - t = time.perf_counter() original_dist = original_dist.numpy() - print(f"t2.cost = {time.perf_counter() - t}") - t = time.perf_counter() del feat - print(f"t3.cost = {time.perf_counter() - t}") - t = time.perf_counter() original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) - print(f"t4.cost = {time.perf_counter() - t}") - t = time.perf_counter() V = np.zeros_like(original_dist).astype(np.float16) - print(f"t5.cost = {time.perf_counter() - t}") - t = time.perf_counter() initial_rank = np.argpartition(original_dist, range(1, k1 + 1)) # 22.2s - print(f"t6.cost = {time.perf_counter() - t}") - t = time.perf_counter() logger.info("Start re-ranking...") for p in range(num_all): @@ -305,46 +285,38 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, # reweight distance using gaussian kernel weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind]) V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight) - print(f"t7.cost = {time.perf_counter() - t}") # 9.2s/15.3s(无) - t = time.perf_counter() # local query expansion original_dist = original_dist[:num_query, ] - print(f"t8.cost = {time.perf_counter() - t}") - t = time.perf_counter() if k2 > 1: - V_qe = np.zeros_like(V, dtype=np.float16) - V_qe_t = paddle.to_tensor(V, dtype="float32") - indices = np.stack([ - np.repeat(np.arange(num_all), k2), - initial_rank[:, :k2].reshape([-1, ]) - ]) # [2, nnz] - values = np.array( - [1 / k2 for _ in range(num_all * k2)], dtype="float32") - Lmat = paddle.sparse.sparse_coo_tensor(indices, values, - original_dist.shape) - V = paddle.sparse.matmul(Lmat, V_qe_t).numpy() - # for p in range(num_all): - # V_qe[p, :] = np.mean(V[initial_rank[p, :k2], :], axis=0) - # V = V_qe - del V_qe - print(f"t9.cost = {time.perf_counter() - t}") # 54.6s - t = time.perf_counter() + try: + # use sparse tensor to speed up query expansion + indices = (np.repeat(np.arange(num_all), k2), + initial_rank[:, :k2].reshape([-1, ])) + values = np.array( + [1 / k2 for _ in range(num_all * k2)], dtype="float16") + V = scipy.sparse.coo_matrix( + (values, indices), V.shape, + dtype="float16") @V.astype("float16") + except Exception as e: + logger.info( + f"Failed to do local query expansion with sparse tensor for reason: \n{e}\n" + f"now use for-loop instead") + # use vanilla for-loop + V_qe = np.zeros_like(V, dtype=np.float16) + for i in range(num_all): + V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) + V = V_qe + del V_qe del initial_rank - print(f"t10.cost = {time.perf_counter() - t}") - t = time.perf_counter() # cache k-reciprocal sets which contains gj invIndex = [] for gj in range(num_all): invIndex.append(np.nonzero(V[:, gj])[0]) - print(f"t11.cost = {time.perf_counter() - t}") - t = time.perf_counter() # compute jaccard distance jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) - print(f"t12.cost = {time.perf_counter() - t}") - t = time.perf_counter() for p in range(num_query): sum_min = np.zeros(shape=[1, num_all], dtype=np.float16) gj_ind = np.nonzero(V[p, :])[0] @@ -353,8 +325,6 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, gi = gj_ind_inv[j] sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj]) jaccard_dist[p] = 1 - sum_min / (2 - sum_min) - print(f"t13.cost = {time.perf_counter() - t}") - t = time.perf_counter() # fuse jaccard distance with original distance final_dist = (1 - lamb) * jaccard_dist + lamb * original_dist @@ -363,6 +333,4 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, del jaccard_dist final_dist = final_dist[:num_query, num_query:] final_dist = paddle.to_tensor(final_dist) - print(f"t14.cost = {time.perf_counter() - t}") - t = time.perf_counter() return final_dist