未验证 提交 ce39aea9 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1187 from RainFrost1/retrieval_dataloader

fix issues when gallery == query dataset
...@@ -54,8 +54,9 @@ def create_operators(params): ...@@ -54,8 +54,9 @@ def create_operators(params):
def build_dataloader(config, mode, device, use_dali=False, seed=None): def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query' assert mode in [
], "Mode should be Train, Eval, Test, Gallery, Query" 'Train', 'Eval', 'Test', 'Gallery', 'Query'
], "Dataset mode should be Train, Eval, Test, Gallery, Query"
# build dataset # build dataset
if use_dali: if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader from ppcls.data.dataloader.dali import dali_dataloader
......
...@@ -111,12 +111,19 @@ class Engine(object): ...@@ -111,12 +111,19 @@ class Engine(object):
self.config["DataLoader"], "Eval", self.device, self.config["DataLoader"], "Eval", self.device,
self.use_dali) self.use_dali)
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
self.gallery_dataloader = build_dataloader( self.gallery_query_dataloader = None
self.config["DataLoader"]["Eval"], "Gallery", self.device, if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali) self.use_dali)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.query_dataloader = build_dataloader( self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query", self.device, self.config["DataLoader"]["Eval"], "Query",
self.use_dali) self.device, self.use_dali)
# build loss # build loss
if self.mode == "train": if self.mode == "train":
......
...@@ -23,6 +23,11 @@ from ppcls.utils import logger ...@@ -23,6 +23,11 @@ from ppcls.utils import logger
def retrieval_eval(evaler, epoch_id=0): def retrieval_eval(evaler, epoch_id=0):
evaler.model.eval() evaler.model.eval()
# step1. build gallery # step1. build gallery
if evaler.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery_query')
query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery') evaler, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature( query_feas, query_img_id, query_query_id = cal_feature(
...@@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'): ...@@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'):
dataloader = evaler.gallery_dataloader dataloader = evaler.gallery_dataloader
elif name == 'query': elif name == 'query':
dataloader = evaler.query_dataloader dataloader = evaler.query_dataloader
elif name == 'gallery_query':
dataloader = evaler.gallery_query_dataloader
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册