提交 1a1eb3a1 编写于 作者: D dongshuilong

fix issues when gallery == query dataset

上级 ec5e07da
...@@ -54,8 +54,9 @@ def create_operators(params): ...@@ -54,8 +54,9 @@ def create_operators(params):
def build_dataloader(config, mode, device, use_dali=False, seed=None): def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query' assert mode in [
], "Mode should be Train, Eval, Test, Gallery, Query" 'Train', 'Eval', 'Test', 'Gallery', 'Query'
], "Dataset mode should be Train, Eval, Test, Gallery, Query"
# build dataset # build dataset
if use_dali: if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader from ppcls.data.dataloader.dali import dali_dataloader
......
...@@ -106,12 +106,19 @@ class Engine(object): ...@@ -106,12 +106,19 @@ class Engine(object):
self.config["DataLoader"], "Eval", self.device, self.config["DataLoader"], "Eval", self.device,
self.use_dali) self.use_dali)
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
self.gallery_dataloader = build_dataloader( self.gallery_query_dataloader = None
self.config["DataLoader"]["Eval"], "Gallery", self.device, if len(self.config["DataLoader"]["Eval"].keys()) == 1:
self.use_dali) key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.query_dataloader = build_dataloader( self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query", self.device, self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali) self.use_dali)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali)
# build loss # build loss
if self.mode == "train": if self.mode == "train":
......
...@@ -23,10 +23,15 @@ from ppcls.utils import logger ...@@ -23,10 +23,15 @@ from ppcls.utils import logger
def retrieval_eval(evaler, epoch_id=0): def retrieval_eval(evaler, epoch_id=0):
evaler.model.eval() evaler.model.eval()
# step1. build gallery # step1. build gallery
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( if evaler.gallery_query_dataloader is not None:
evaler, name='gallery') gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
query_feas, query_img_id, query_query_id = cal_feature( evaler, name='gallery_query')
evaler, name='query') query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature(
evaler, name='query')
# step2. do evaluation # step2. do evaluation
sim_block_size = evaler.config["Global"].get("sim_block_size", 64) sim_block_size = evaler.config["Global"].get("sim_block_size", 64)
...@@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'): ...@@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'):
dataloader = evaler.gallery_dataloader dataloader = evaler.gallery_dataloader
elif name == 'query': elif name == 'query':
dataloader = evaler.query_dataloader dataloader = evaler.query_dataloader
elif name == 'gallery_query':
dataloader = evaler.gallery_query_dataloader
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
...@@ -124,7 +131,7 @@ def cal_feature(evaler, name='gallery'): ...@@ -124,7 +131,7 @@ def cal_feature(evaler, name='gallery'):
feas_norm = paddle.sqrt( feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm) batch_feas = paddle.divide(batch_feas, feas_norm)
# do binarize # do binarize
if evaler.config["Global"].get("feature_binarize") == "round": if evaler.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
...@@ -142,10 +149,10 @@ def cal_feature(evaler, name='gallery'): ...@@ -142,10 +149,10 @@ def cal_feature(evaler, name='gallery'):
all_image_id = paddle.concat([all_image_id, batch[1]]) all_image_id = paddle.concat([all_image_id, batch[1]])
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if evaler.use_dali: if evaler.use_dali:
dataloader_tmp.reset() dataloader_tmp.reset()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
img_id_list = [] img_id_list = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册