提交 f525cea0 编写于 作者: G gaotingquan 提交者: Wei Shengyu

replace the arg engine with config

上级 e4a3e1bb
...@@ -220,23 +220,21 @@ class DataIterator(object): ...@@ -220,23 +220,21 @@ class DataIterator(object):
return batch return batch
def build_dataloader(engine): def build_dataloader(config, mode):
if "class_num" in engine.config["Global"]: if "class_num" in config["Global"]:
global_class_num = engine.config["Global"]["class_num"] global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]: if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}." msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else: else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored." msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg) logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num}) config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({ config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
"epochs": engine.config["Global"]["epochs"]
})
use_dali = engine.use_dali use_dali = config["Global"].get("use_dali", False)
dataloader_dict = { dataloader_dict = {
"Train": None, "Train": None,
"UnLabelTrain": None, "UnLabelTrain": None,
...@@ -245,37 +243,37 @@ def build_dataloader(engine): ...@@ -245,37 +243,37 @@ def build_dataloader(engine):
"Gallery": None, "Gallery": None,
"GalleryQuery": None "GalleryQuery": None
} }
if engine.mode == 'train': if mode == 'train':
train_dataloader = build( train_dataloader = build(
engine.config["DataLoader"], "Train", use_dali, seed=None) config["DataLoader"], "Train", use_dali, seed=None)
if engine.config["DataLoader"]["Train"].get("max_iter", None): if config["DataLoader"]["Train"].get("max_iter", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch. # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
max_iter = engine.config["Train"].get("max_iter") max_iter = config["Train"].get("max_iter")
max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq update_freq = config["Global"].get("update_freq", 1)
max_iter = train_dataloader.max_iter // update_freq * update_freq
train_dataloader.max_iter = max_iter train_dataloader.max_iter = max_iter
if engine.config["DataLoader"]["Train"].get("convert_iterator", True): if config["DataLoader"]["Train"].get("convert_iterator", True):
train_dataloader = DataIterator(train_dataloader, use_dali) train_dataloader = DataIterator(train_dataloader, use_dali)
dataloader_dict["Train"] = train_dataloader dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None: if config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build( dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None) config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
if engine.mode == "eval" or (engine.mode == "train" and if mode == "eval" or (mode == "train" and
engine.config["Global"]["eval_during_train"]): config["Global"]["eval_during_train"]):
if engine.config["Global"][ if config["Global"]["eval_mode"] in ["classification", "adaface"]:
"eval_mode"] in ["classification", "adaface"]:
dataloader_dict["Eval"] = build( dataloader_dict["Eval"] = build(
engine.config["DataLoader"], "Eval", use_dali, seed=None) config["DataLoader"], "Eval", use_dali, seed=None)
elif engine.config["Global"]["eval_mode"] == "retrieval": elif config["Global"]["eval_mode"] == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1: if len(config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0] key = list(config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build( dataloader_dict["GalleryQuery"] = build(
engine.config["DataLoader"]["Eval"], key, use_dali) config["DataLoader"]["Eval"], key, use_dali)
else: else:
dataloader_dict["Gallery"] = build( dataloader_dict["Gallery"] = build(
engine.config["DataLoader"]["Eval"], "Gallery", use_dali) config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build( dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
engine.config["DataLoader"]["Eval"], "Query", use_dali) "Query", use_dali)
return dataloader_dict return dataloader_dict
...@@ -76,7 +76,7 @@ class Engine(object): ...@@ -76,7 +76,7 @@ class Engine(object):
# build dataloader # build dataloader
self.use_dali = self.config["Global"].get("use_dali", False) self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self) self.dataloader_dict = build_dataloader(self.config, mode)
# build loss # build loss
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss( self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册