提交 c03a66bf 编写于 作者: H HydrogenSulfate

Rename variable names that may be confused in retrieval.py

上级 2f5c0c71
......@@ -25,32 +25,35 @@ from ppcls.utils import logger
def retrieval_eval(engine, epoch_id=0):
engine.model.eval()
# step1. build gallery
# step1. build query & gallery
if engine.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
engine, name='gallery_query')
query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
query_feas, query_img_id, query_unique_id = gallery_feas, gallery_img_id, gallery_unique_id
else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
engine, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature(
query_feas, query_img_id, query_unique_id = cal_feature(
engine, name='query')
# step2. do evaluation
# step2. split data into blocks so as to save memory
sim_block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_query_id is not None:
query_id_blocks = paddle.split(
query_query_id, num_or_sections=sections)
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
if query_unique_id is not None:
query_unique_id_blocks = paddle.split(
query_unique_id, num_or_sections=sections)
query_img_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None
# step3. do evaluation
if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.}
else:
# do evaluation with re-ranking(k-reciprocal)
reranking_flag = engine.config['Global'].get('re_ranking', False)
logger.info(f"re_ranking={reranking_flag}")
metric_dict = dict()
......@@ -70,9 +73,9 @@ def retrieval_eval(engine, epoch_id=0):
query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3)
# compute keep mask
query_id_mask = (query_query_id != gallery_unique_id.t())
unique_id_mask = (query_unique_id != gallery_unique_id.t())
image_id_mask = (query_img_id != gallery_img_id.t())
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
keep_mask = paddle.logical_or(image_id_mask, unique_id_mask)
# set inf(1e9) distance to those exist in gallery
distmat = distmat * keep_mask.astype("float32")
......@@ -85,24 +88,27 @@ def retrieval_eval(engine, epoch_id=0):
for key in metric_tmp:
metric_dict[key] = metric_tmp[key]
else:
# do evaluation without re-ranking
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) # [n,m]
if query_query_id is not None:
query_id_block = query_id_blocks[block_idx]
query_id_mask = (query_id_block != gallery_unique_id.t())
if query_unique_id is not None:
query_unique_id_block = query_unique_id_blocks[block_idx]
unique_id_mask = (
query_unique_id_block != gallery_unique_id.t())
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id.t())
query_img_id_block = query_img_id_blocks[block_idx]
image_id_mask = (query_img_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
keep_mask = paddle.logical_or(image_id_mask,
unique_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
else:
keep_mask = None
metric_tmp = engine.eval_metric_func(
similarity_matrix, image_id_blocks[block_idx],
similarity_matrix, query_img_id_blocks[block_idx],
gallery_img_id, keep_mask)
for key in metric_tmp:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册