From f42719afbb60c04beb92b959fed524ad1b7d7aaf Mon Sep 17 00:00:00 2001 From: Tingquan Gao <35441050@qq.com> Date: Tue, 14 Mar 2023 16:16:40 +0800 Subject: [PATCH] Revert "replace the arg engine with config" This reverts commit f525cea006c66b11e6fee69b2088d0adc08ca57f. --- ppcls/data/__init__.py | 58 ++++++++++++++++++++++-------------------- ppcls/engine/engine.py | 2 +- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 923285cb..a964a831 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -220,21 +220,23 @@ class DataIterator(object): return batch -def build_dataloader(config, mode): - if "class_num" in config["Global"]: - global_class_num = config["Global"]["class_num"] +def build_dataloader(engine): + if "class_num" in engine.config["Global"]: + global_class_num = engine.config["Global"]["class_num"] if "class_num" not in config["Arch"]: - config["Arch"]["class_num"] = global_class_num + engine.config["Arch"]["class_num"] = global_class_num msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}." else: msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored." logger.warning(msg) - class_num = config["Arch"].get("class_num", None) - config["DataLoader"].update({"class_num": class_num}) - config["DataLoader"].update({"epochs": config["Global"]["epochs"]}) + class_num = engine.config["Arch"].get("class_num", None) + engine.config["DataLoader"].update({"class_num": class_num}) + engine.config["DataLoader"].update({ + "epochs": engine.config["Global"]["epochs"] + }) - use_dali = config["Global"].get("use_dali", False) + use_dali = engine.use_dali dataloader_dict = { "Train": None, "UnLabelTrain": None, @@ -243,37 +245,37 @@ def build_dataloader(config, mode): "Gallery": None, "GalleryQuery": None } - if mode == 'train': + if engine.mode == 'train': train_dataloader = build( - config["DataLoader"], "Train", use_dali, seed=None) + engine.config["DataLoader"], "Train", use_dali, seed=None) - if config["DataLoader"]["Train"].get("max_iter", None): + if engine.config["DataLoader"]["Train"].get("max_iter", None): # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch. - max_iter = config["Train"].get("max_iter") - update_freq = config["Global"].get("update_freq", 1) - max_iter = train_dataloader.max_iter // update_freq * update_freq + max_iter = engine.config["Train"].get("max_iter") + max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq train_dataloader.max_iter = max_iter - if config["DataLoader"]["Train"].get("convert_iterator", True): + if engine.config["DataLoader"]["Train"].get("convert_iterator", True): train_dataloader = DataIterator(train_dataloader, use_dali) dataloader_dict["Train"] = train_dataloader - if config["DataLoader"].get('UnLabelTrain', None) is not None: + if engine.config["DataLoader"].get('UnLabelTrain', None) is not None: dataloader_dict["UnLabelTrain"] = build( - config["DataLoader"], "UnLabelTrain", use_dali, seed=None) + engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None) - if mode == "eval" or (mode == "train" and - config["Global"]["eval_during_train"]): - if config["Global"]["eval_mode"] in ["classification", "adaface"]: + if engine.mode == "eval" or (engine.mode == "train" and + engine.config["Global"]["eval_during_train"]): + if engine.config["Global"][ + "eval_mode"] in ["classification", "adaface"]: dataloader_dict["Eval"] = build( - config["DataLoader"], "Eval", use_dali, seed=None) - elif config["Global"]["eval_mode"] == "retrieval": - if len(config["DataLoader"]["Eval"].keys()) == 1: - key = list(config["DataLoader"]["Eval"].keys())[0] + engine.config["DataLoader"], "Eval", use_dali, seed=None) + elif engine.config["Global"]["eval_mode"] == "retrieval": + if len(engine.config["DataLoader"]["Eval"].keys()) == 1: + key = list(engine.config["DataLoader"]["Eval"].keys())[0] dataloader_dict["GalleryQuery"] = build( - config["DataLoader"]["Eval"], key, use_dali) + engine.config["DataLoader"]["Eval"], key, use_dali) else: dataloader_dict["Gallery"] = build( - config["DataLoader"]["Eval"], "Gallery", use_dali) - dataloader_dict["Query"] = build(config["DataLoader"]["Eval"], - "Query", use_dali) + engine.config["DataLoader"]["Eval"], "Gallery", use_dali) + dataloader_dict["Query"] = build( + engine.config["DataLoader"]["Eval"], "Query", use_dali) return dataloader_dict diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 3e1ee7ac..d29eb969 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -76,7 +76,7 @@ class Engine(object): # build dataloader self.use_dali = self.config["Global"].get("use_dali", False) - self.dataloader_dict = build_dataloader(self.config, mode) + self.dataloader_dict = build_dataloader(self) # build loss self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss( -- GitLab