From c03a66bfe43f2cbb092d375c3116ac02a3b42f6d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 26 Aug 2022 06:16:11 +0000 Subject: [PATCH] Rename variable names that may be confused in retrieval.py --- ppcls/engine/evaluation/retrieval.py | 40 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 02cae167..ef4bbd24 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -25,32 +25,35 @@ from ppcls.utils import logger def retrieval_eval(engine, epoch_id=0): engine.model.eval() - # step1. build gallery + # step1. build query & gallery if engine.gallery_query_dataloader is not None: gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( engine, name='gallery_query') - query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id + query_feas, query_img_id, query_unique_id = gallery_feas, gallery_img_id, gallery_unique_id else: gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( engine, name='gallery') - query_feas, query_img_id, query_query_id = cal_feature( + query_feas, query_img_id, query_unique_id = cal_feature( engine, name='query') - # step2. do evaluation + # step2. split data into blocks so as to save memory sim_block_size = engine.config["Global"].get("sim_block_size", 64) sections = [sim_block_size] * (len(query_feas) // sim_block_size) if len(query_feas) % sim_block_size: sections.append(len(query_feas) % sim_block_size) + fea_blocks = paddle.split(query_feas, num_or_sections=sections) - if query_query_id is not None: - query_id_blocks = paddle.split( - query_query_id, num_or_sections=sections) - image_id_blocks = paddle.split(query_img_id, num_or_sections=sections) + if query_unique_id is not None: + query_unique_id_blocks = paddle.split( + query_unique_id, num_or_sections=sections) + query_img_id_blocks = paddle.split(query_img_id, num_or_sections=sections) metric_key = None + # step3. do evaluation if engine.eval_loss_func is None: metric_dict = {metric_key: 0.} else: + # do evaluation with re-ranking(k-reciprocal) reranking_flag = engine.config['Global'].get('re_ranking', False) logger.info(f"re_ranking={reranking_flag}") metric_dict = dict() @@ -70,9 +73,9 @@ def retrieval_eval(engine, epoch_id=0): query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3) # compute keep mask - query_id_mask = (query_query_id != gallery_unique_id.t()) + unique_id_mask = (query_unique_id != gallery_unique_id.t()) image_id_mask = (query_img_id != gallery_img_id.t()) - keep_mask = paddle.logical_or(query_id_mask, image_id_mask) + keep_mask = paddle.logical_or(image_id_mask, unique_id_mask) # set inf(1e9) distance to those exist in gallery distmat = distmat * keep_mask.astype("float32") @@ -85,24 +88,27 @@ def retrieval_eval(engine, epoch_id=0): for key in metric_tmp: metric_dict[key] = metric_tmp[key] else: + # do evaluation without re-ranking for block_idx, block_fea in enumerate(fea_blocks): similarity_matrix = paddle.matmul( block_fea, gallery_feas, transpose_y=True) # [n,m] - if query_query_id is not None: - query_id_block = query_id_blocks[block_idx] - query_id_mask = (query_id_block != gallery_unique_id.t()) + if query_unique_id is not None: + query_unique_id_block = query_unique_id_blocks[block_idx] + unique_id_mask = ( + query_unique_id_block != gallery_unique_id.t()) - image_id_block = image_id_blocks[block_idx] - image_id_mask = (image_id_block != gallery_img_id.t()) + query_img_id_block = query_img_id_blocks[block_idx] + image_id_mask = (query_img_id_block != gallery_img_id.t()) - keep_mask = paddle.logical_or(query_id_mask, image_id_mask) + keep_mask = paddle.logical_or(image_id_mask, + unique_id_mask) similarity_matrix = similarity_matrix * keep_mask.astype( "float32") else: keep_mask = None metric_tmp = engine.eval_metric_func( - similarity_matrix, image_id_blocks[block_idx], + similarity_matrix, query_img_id_blocks[block_idx], gallery_img_id, keep_mask) for key in metric_tmp: -- GitLab