diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 942391d393fe0074e8d57c28da3394fd186e4e31..6d2914bc2059cad88e32594c03d3a9c66763f4ba 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -110,7 +110,7 @@ def build(config, mode, use_dali=False, seed=None): config_dataset = copy.deepcopy(config_dataset) dataset_name = config_dataset.pop('name') if 'batch_transform_ops' in config_dataset: - batch_transform = config_dataset['batch_transform_ops'] + batch_transform = config_dataset.pop('batch_transform_ops') else: batch_transform = None @@ -254,11 +254,10 @@ def build_dataloader(config, mode): if mode == "eval" or (mode == "train" and config["Global"]["eval_during_train"]): - task = config["Global"].get("task", "classification") - if task in ["classification", "adaface"]: + if config["Global"]["eval_mode"] in ["classification", "adaface"]: dataloader_dict["Eval"] = build( config["DataLoader"], "Eval", use_dali, seed=None) - elif task == "retrieval": + elif config["Global"]["eval_mode"] == "retrieval": if len(config["DataLoader"]["Eval"].keys()) == 1: key = list(config["DataLoader"]["Eval"].keys())[0] dataloader_dict["GalleryQuery"] = build( diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index bd654aafaa2e9ee658c37eff7cbd860b97e61890..9720262c65a48a820f8b26d236904cb55201a340 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -42,7 +42,7 @@ from ppcls.data.preprocess.ops.dali_operators import RandomRot90 from ppcls.data.preprocess.ops.dali_operators import RandomRotation from ppcls.data.preprocess.ops.dali_operators import ResizeImage from ppcls.data.preprocess.ops.dali_operators import ToCHWImage -from ppcls.utils import type_name +from ppcls.engine.train.utils import type_name from ppcls.utils import logger INTERP_MAP = { diff --git a/ppcls/engine/train/train_progressive.py b/ppcls/engine/train/train_progressive.py index f1bba12de6c8cf840cb46bb89b1516546ea8dfb7..d53a5f56f2091c47b7e1b4618df0e8407b163ac6 100644 --- a/ppcls/engine/train/train_progressive.py +++ b/ppcls/engine/train/train_progressive.py @@ -14,7 +14,8 @@ from __future__ import absolute_import, division, print_function from ppcls.data import build_dataloader -from ppcls.utils import logger, type_name +from ppcls.engine.train.utils import type_name +from ppcls.utils import logger from .regular_train_epoch import regular_train_epoch diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py index 087e943a49913204e7475fff316a36a3d0a9a477..c31084e844cd685acb3ee3bbd5feee717b7d14c8 100644 --- a/ppcls/engine/train/utils.py +++ b/ppcls/engine/train/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import, division, print_function import datetime -from ppcls.utils import logger, type_name +from ppcls.utils import logger from ppcls.utils.misc import AverageMeter @@ -75,3 +75,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id): value=trainer.output_info[key].avg, step=trainer.global_step, writer=trainer.vdl_writer) + + +def type_name(object: object) -> str: + """get class name of an object""" + return object.__class__.__name__ diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 52b3bdc96900a117c86cba719932e1ce8dbe20e7..b3e4becc0df9a6f4c787c1cf65fa02980d5e8ce6 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -94,7 +94,6 @@ def build_loss(config, mode="train"): if unlabel_loss_info: unlabel_train_loss_func = CombinedLoss( copy.deepcopy(unlabel_loss_info)) - return train_loss_func, unlabel_train_loss_func if mode == "eval" or (mode == "train" and config["Global"]["eval_during_train"]): loss_config = config.get("Loss", None) @@ -102,4 +101,5 @@ def build_loss(config, mode="train"): loss_config = loss_config.get("Eval") if loss_config is not None: eval_loss_func = CombinedLoss(copy.deepcopy(loss_config)) - return eval_loss_func + + return train_loss_func, unlabel_train_loss_func, eval_loss_func diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index b614c116efe3ad55e094cdbc3c0c24ad567634e9..55a3f345ace90d10e7214c8dc0d1446289b88fc9 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -65,19 +65,22 @@ class CombinedMetrics(AvgMetrics): metric.reset() -def build_metrics(config, mode): +def build_metrics(engine): + config, mode = engine.config, engine.mode if mode == 'train' and "Metric" in config and "Train" in config[ "Metric"] and config["Metric"]["Train"]: metric_config = config["Metric"]["Train"] - if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops", - None): + if hasattr(engine.dataloader_dict["Train"], + "collate_fn") and engine.dataloader_dict[ + "Train"].collate_fn is not None: for m_idx, m in enumerate(metric_config): if "TopkAcc" in m: msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed." logger.warning(msg) metric_config.pop(m_idx) train_metric_func = CombinedMetrics(copy.deepcopy(metric_config)) - return train_metric_func + else: + train_metric_func = None if mode == "eval" or (mode == "train" and config["Global"]["eval_during_train"]): @@ -94,4 +97,7 @@ def build_metrics(config, mode): else: metric_config = [{"name": "Recallk", "topk": (1, 5)}] eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config)) - return eval_metric_func + else: + eval_metric_func = None + + return train_metric_func, eval_metric_func diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py index ccccd6f385f2a7b91b886bdefd2750056ccb8a4f..559a340d1218dde0d7134c274018ed1cc4749da6 100644 --- a/ppcls/optimizer/__init__.py +++ b/ppcls/optimizer/__init__.py @@ -21,7 +21,8 @@ import copy import paddle from typing import Dict, List -from ..utils import logger, type_name +from ppcls.engine.train.utils import type_name +from ppcls.utils import logger from . import optimizer @@ -44,10 +45,14 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): # model_list is None in static graph -def build_optimizer(config, max_iter, model_list, update_freq): +def build_optimizer(engine): + if engine.mode != "train": + return None, None + config, max_iter, model_list = engine.config, engine.dataloader_dict[ + "Train"].max_iter, [engine.model, engine.train_loss_func] optim_config = copy.deepcopy(config["Optimizer"]) epochs = config["Global"]["epochs"] - update_freq = config["Global"].get("update_freq", 1) + update_freq = engine.update_freq step_each_epoch = max_iter // update_freq if isinstance(optim_config, dict): # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}] diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py index 08c26bb2b8dcd85624c2af9d5a1eb175f6def6ec..59c0d050797f1bec2d06fa8b1bba6125eb78107c 100644 --- a/ppcls/utils/__init__.py +++ b/ppcls/utils/__init__.py @@ -26,8 +26,3 @@ from .metrics import multi_hot_encode from .metrics import precision_recall_fscore from .misc import AverageMeter from .save_load import init_model - - -def type_name(object: object) -> str: - """get class name of an object""" - return object.__class__.__name__