Commit 5d06a88a authored by Tingquan Gao

Revert "refactor: simplify engine"

This reverts commit 376d83d4.
Parent 6aabb94d
@@ -15,8 +15,6 @@
import inspect
import copy
import random
import platform
import paddle
import numpy as np
import paddle.distributed as dist
@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed)
def build(config, mode, device, use_dali=False, seed=None):
def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None):
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
def build_dataloader(engine):
if "class_num" in engine.config["Global"]:
global_class_num = engine.config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
use_dali = engine.config['Global'].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
"Eval": None,
"Query": None,
"Gallery": None,
"GalleryQuery": None
}
if engine.mode == 'train':
train_dataloader = build(
engine.config["DataLoader"],
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None):
# set max iterations per epoch manually when training by iteration(s), such as XBM, FixMatch.
iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
engine.iter_per_epoch = iter_per_epoch
train_dataloader.iter_per_epoch = iter_per_epoch
dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"],
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.eval_mode in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"],
"Eval",
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], key, engine.device,
use_dali)
else:
dataloader_dict["Gallery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Gallery",
engine.device, use_dali)
dataloader_dict["Query"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict
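For context on the removed helper above: the `iter_per_epoch` arithmetic truncates the per-epoch iteration count to a whole number of gradient-accumulation windows, so an epoch never ends partway through one. A minimal, framework-free sketch of that rounding (the function name and numbers are illustrative, not from this repo):

```python
# Sketch of the iter_per_epoch rounding: floor-divide by update_freq, then
# multiply back, dropping any trailing batches that would form a partial
# gradient-accumulation window.
def effective_iters(num_batches: int, update_freq: int) -> int:
    return num_batches // update_freq * update_freq

assert effective_iters(1003, 4) == 1000  # 3 trailing batches dropped
assert effective_iters(1000, 4) == 1000  # already a multiple, unchanged
```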
This diff is collapsed.
@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
def __init__(self, config_list):
super().__init__()
loss_func = []
self.loss_func = []
self.loss_weight = []
assert isinstance(config_list, list), (
'operator config should be a list')
@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer):
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(loss_func)
logger.debug("build loss {} success.".format(loss_func))
self.loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(self.loss_func)
def __call__(self, input, batch):
loss_dict = {}
@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer):
return loss_dict
def build_loss(config, mode="train"):
train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
if mode == "train":
label_loss_info = config["Loss"]["Train"]
if label_loss_info:
train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info))
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return train_loss_func, unlabel_train_loss_func, eval_loss_func
def build_loss(config):
if config is None:
return None
module_class = CombinedLoss(copy.deepcopy(config))
logger.debug("build loss {} success.".format(module_class))
return module_class
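The restored `build_loss(config)` consumes the same per-loss config shape that `CombinedLoss` expects: a list of single-key dicts, each carrying a mandatory `weight` that is popped before the loss class is instantiated. A framework-free sketch of that shape; the loss names and values below are illustrative, not taken from a real training config:

```python
import copy

# Illustrative config: each entry is {LossClassName: {"weight": ..., **kwargs}}.
config_list = [
    {"CELoss": {"weight": 1.0, "epsilon": 0.1}},
    {"TripletLossV2": {"weight": 0.5, "margin": 0.3}},
]

weights = []
for item in copy.deepcopy(config_list):  # deepcopy, as in build_loss, so pop() leaves the config intact
    name = list(item)[0]
    param = item[name]
    weights.append(param.pop("weight"))  # weight is consumed before eval(name)(**param)
print(weights)  # [1.0, 0.5]
```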
@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics):
metric.reset()
def build_metrics(engine):
config, mode = engine.config, engine.mode
if mode == 'train' and "Metric" in config and "Train" in config[
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if hasattr(engine.train_dataloader, "collate_fn"
) and engine.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
train_metric_func = None
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
eval_mode = config["Global"].get("eval_mode", "classification")
if eval_mode == "classification":
if "Metric" in config and "Eval" in config["Metric"]:
eval_metric_func = CombinedMetrics(
copy.deepcopy(config["Metric"]["Eval"]))
else:
eval_metric_func = None
elif eval_mode == "retrieval":
if "Metric" in config and "Eval" in config["Metric"]:
metric_config = config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
eval_metric_func = None
return train_metric_func, eval_metric_func
def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config))
return metrics_list
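One caveat in the removed `build_metrics(engine)` above: it pops entries from `metric_config` while enumerating it, which skips the element that slides into a freed slot when two matching entries are adjacent. A sketch, with an illustrative config, of a filter that avoids this:

```python
# Removing items from a list during enumeration skips the successor of each
# removed element; a comprehension rebuilds the list safely instead.
metric_config = [{"TopkAcc": {}}, {"TopkAcc": {"topk": 5}}, {"Recallk": {}}]
kept = [m for m in metric_config if "TopkAcc" not in m]
print(kept)  # [{'Recallk': {}}]
```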
@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, dataloader, model_list=None):
optim_config = copy.deepcopy(config["Optimizer"])
epochs = config["Global"]["epochs"]
update_freq = config["Global"].get("update_freq", 1)
step_each_epoch = dataloader.iter_per_epoch // update_freq
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
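The comment in the restored `build_optimizer` describes normalizing a single-optimizer dict config into a list of scoped dicts. A minimal sketch of that conversion; the `"all"` scope value here is an assumption for illustration:

```python
# Sketch: {'name': 'Momentum', 'lr': 0.1} -> [{'Momentum': {'scope': 'all', 'lr': 0.1}}]
optim_config = {"name": "Momentum", "lr": 0.1}
optim_name = optim_config.pop("name")
optim_config = [{optim_name: {"scope": "all", **optim_config}}]
print(optim_config)  # [{'Momentum': {'scope': 'all', 'lr': 0.1}}]
```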
@@ -22,15 +22,16 @@ import paddle.distributed as dist
_logger = None
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
added. If `log_file` is specified, a FileHandler will also be added.
Args:
config(dict): Training config.
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" and thus be silent most of the time.
@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
if init_flag:
_logger.addHandler(stream_handler)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], f"{mode}.log")
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
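With the restored signature, the caller assembles the log path itself instead of having it derived from `Global.output_dir` and `Arch.name` as in the removed version. A hedged usage sketch; the path below is purely illustrative:

```python
import logging

# Under the restored API, the caller chooses where the log file lives.
init_logger(name="ppcls", log_file="./output/train.log", log_level=logging.INFO)
```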