提交 f525cea0 编写于 作者: G gaotingquan 提交者: Wei Shengyu

replace the arg engine with config

上级 e4a3e1bb
......@@ -220,23 +220,21 @@ class DataIterator(object):
return batch
def build_dataloader(engine):
if "class_num" in engine.config["Global"]:
global_class_num = engine.config["Global"]["class_num"]
def build_dataloader(config, mode):
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
use_dali = engine.use_dali
use_dali = config["Global"].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
......@@ -245,37 +243,37 @@ def build_dataloader(engine):
"Gallery": None,
"GalleryQuery": None
}
if engine.mode == 'train':
if mode == 'train':
train_dataloader = build(
engine.config["DataLoader"], "Train", use_dali, seed=None)
config["DataLoader"], "Train", use_dali, seed=None)
if engine.config["DataLoader"]["Train"].get("max_iter", None):
if config["DataLoader"]["Train"].get("max_iter", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
max_iter = engine.config["Train"].get("max_iter")
max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq
max_iter = config["Train"].get("max_iter")
update_freq = config["Global"].get("update_freq", 1)
max_iter = train_dataloader.max_iter // update_freq * update_freq
train_dataloader.max_iter = max_iter
if engine.config["DataLoader"]["Train"].get("convert_iterator", True):
if config["DataLoader"]["Train"].get("convert_iterator", True):
train_dataloader = DataIterator(train_dataloader, use_dali)
dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
if config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.config["Global"][
"eval_mode"] in ["classification", "adaface"]:
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
if config["Global"]["eval_mode"] in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"], "Eval", use_dali, seed=None)
elif engine.config["Global"]["eval_mode"] == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
config["DataLoader"], "Eval", use_dali, seed=None)
elif config["Global"]["eval_mode"] == "retrieval":
if len(config["DataLoader"]["Eval"].keys()) == 1:
key = list(config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build(
engine.config["DataLoader"]["Eval"], key, use_dali)
config["DataLoader"]["Eval"], key, use_dali)
else:
dataloader_dict["Gallery"] = build(
engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build(
engine.config["DataLoader"]["Eval"], "Query", use_dali)
config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
"Query", use_dali)
return dataloader_dict
......@@ -76,7 +76,7 @@ class Engine(object):
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self)
self.dataloader_dict = build_dataloader(self.config, mode)
# build loss
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册