提交 1a1eb3a1 编写于 作者: D dongshuilong

fix issues when gallery == query dataset

上级 ec5e07da
......@@ -54,8 +54,9 @@ def create_operators(params):
def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
], "Mode should be Train, Eval, Test, Gallery, Query"
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query'
], "Dataset mode should be Train, Eval, Test, Gallery, Query"
# build dataset
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
......
......@@ -106,12 +106,19 @@ class Engine(object):
self.config["DataLoader"], "Eval", self.device,
self.use_dali)
elif self.eval_mode == "retrieval":
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery", self.device,
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query", self.device,
self.use_dali)
self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali)
# build loss
if self.mode == "train":
......
......@@ -23,6 +23,11 @@ from ppcls.utils import logger
def retrieval_eval(evaler, epoch_id=0):
evaler.model.eval()
# step1. build gallery
if evaler.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery_query')
query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature(
......@@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'):
dataloader = evaler.gallery_dataloader
elif name == 'query':
dataloader = evaler.query_dataloader
elif name == 'gallery_query':
dataloader = evaler.gallery_query_dataloader
else:
raise RuntimeError("Only support gallery or query dataset")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册