提交 f42719af 编写于 作者: T Tingquan Gao

Revert "replace the arg engine with config"

This reverts commit f525cea0.
上级 7243f142
...@@ -220,21 +220,23 @@ class DataIterator(object): ...@@ -220,21 +220,23 @@ class DataIterator(object):
return batch return batch
def build_dataloader(config, mode): def build_dataloader(engine):
if "class_num" in config["Global"]: if "class_num" in engine.config["Global"]:
global_class_num = config["Global"]["class_num"] global_class_num = engine.config["Global"]["class_num"]
if "class_num" not in config["Arch"]: if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num engine.config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}." msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else: else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored." msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg) logger.warning(msg)
class_num = config["Arch"].get("class_num", None) class_num = engine.config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num}) engine.config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]}) engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
use_dali = config["Global"].get("use_dali", False) use_dali = engine.use_dali
dataloader_dict = { dataloader_dict = {
"Train": None, "Train": None,
"UnLabelTrain": None, "UnLabelTrain": None,
...@@ -243,37 +245,37 @@ def build_dataloader(config, mode): ...@@ -243,37 +245,37 @@ def build_dataloader(config, mode):
"Gallery": None, "Gallery": None,
"GalleryQuery": None "GalleryQuery": None
} }
if mode == 'train': if engine.mode == 'train':
train_dataloader = build( train_dataloader = build(
config["DataLoader"], "Train", use_dali, seed=None) engine.config["DataLoader"], "Train", use_dali, seed=None)
if config["DataLoader"]["Train"].get("max_iter", None): if engine.config["DataLoader"]["Train"].get("max_iter", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch. # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
max_iter = config["Train"].get("max_iter") max_iter = engine.config["Train"].get("max_iter")
update_freq = config["Global"].get("update_freq", 1) max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq
max_iter = train_dataloader.max_iter // update_freq * update_freq
train_dataloader.max_iter = max_iter train_dataloader.max_iter = max_iter
if config["DataLoader"]["Train"].get("convert_iterator", True): if engine.config["DataLoader"]["Train"].get("convert_iterator", True):
train_dataloader = DataIterator(train_dataloader, use_dali) train_dataloader = DataIterator(train_dataloader, use_dali)
dataloader_dict["Train"] = train_dataloader dataloader_dict["Train"] = train_dataloader
if config["DataLoader"].get('UnLabelTrain', None) is not None: if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build( dataloader_dict["UnLabelTrain"] = build(
config["DataLoader"], "UnLabelTrain", use_dali, seed=None) engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
if mode == "eval" or (mode == "train" and if engine.mode == "eval" or (engine.mode == "train" and
config["Global"]["eval_during_train"]): engine.config["Global"]["eval_during_train"]):
if config["Global"]["eval_mode"] in ["classification", "adaface"]: if engine.config["Global"][
"eval_mode"] in ["classification", "adaface"]:
dataloader_dict["Eval"] = build( dataloader_dict["Eval"] = build(
config["DataLoader"], "Eval", use_dali, seed=None) engine.config["DataLoader"], "Eval", use_dali, seed=None)
elif config["Global"]["eval_mode"] == "retrieval": elif engine.config["Global"]["eval_mode"] == "retrieval":
if len(config["DataLoader"]["Eval"].keys()) == 1: if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(config["DataLoader"]["Eval"].keys())[0] key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build( dataloader_dict["GalleryQuery"] = build(
config["DataLoader"]["Eval"], key, use_dali) engine.config["DataLoader"]["Eval"], key, use_dali)
else: else:
dataloader_dict["Gallery"] = build( dataloader_dict["Gallery"] = build(
config["DataLoader"]["Eval"], "Gallery", use_dali) engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build(config["DataLoader"]["Eval"], dataloader_dict["Query"] = build(
"Query", use_dali) engine.config["DataLoader"]["Eval"], "Query", use_dali)
return dataloader_dict return dataloader_dict
...@@ -76,7 +76,7 @@ class Engine(object): ...@@ -76,7 +76,7 @@ class Engine(object):
# build dataloader # build dataloader
self.use_dali = self.config["Global"].get("use_dali", False) self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self.config, mode) self.dataloader_dict = build_dataloader(self)
# build loss # build loss
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss( self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册