提交 5d06a88a 编写于 作者: T Tingquan Gao

Revert "refactor: simplify engine"

This reverts commit 376d83d4.
上级 6aabb94d
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
import inspect import inspect
import copy import copy
import random import random
import platform
import paddle import paddle
import numpy as np import numpy as np
import paddle.distributed as dist import paddle.distributed as dist
...@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): ...@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build(config, mode, device, use_dali=False, seed=None): def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in [ assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
...@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None): ...@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None):
logger.debug("build data_loader({}) success...".format(data_loader)) logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader return data_loader
def build_dataloader(engine):
if "class_num" in engine.config["Global"]:
global_class_num = engine.config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
use_dali = engine.config['Global'].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
"Eval": None,
"Query": None,
"Gallery": None,
"GalleryQuery": None
}
if engine.mode == 'train':
train_dataloader = build(
engine.config["DataLoader"],
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
engine.iter_per_epoch = iter_per_epoch
train_dataloader.iter_per_epoch = iter_per_epoch
dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"],
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.eval_mode in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"],
"Eval",
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], key, engine.device,
use_dali)
else:
dataloader_dict["Gallery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Gallery",
engine.device, use_dali)
dataloader_dict["Query"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict
此差异已折叠。
...@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss ...@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
super().__init__() super().__init__()
loss_func = [] self.loss_func = []
self.loss_weight = [] self.loss_weight = []
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
...@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer): ...@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer):
assert "weight" in param, "weight must be in param, but param just contains {}".format( assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys()) param.keys())
self.loss_weight.append(param.pop("weight")) self.loss_weight.append(param.pop("weight"))
loss_func.append(eval(name)(**param)) self.loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(loss_func) self.loss_func = nn.LayerList(self.loss_func)
logger.debug("build loss {} success.".format(loss_func))
def __call__(self, input, batch): def __call__(self, input, batch):
loss_dict = {} loss_dict = {}
...@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer): ...@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer):
return loss_dict return loss_dict
def build_loss(config, mode="train"): def build_loss(config):
train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None if config is None:
if mode == "train": return None
label_loss_info = config["Loss"]["Train"] module_class = CombinedLoss(copy.deepcopy(config))
if label_loss_info: logger.debug("build loss {} success.".format(module_class))
train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info)) return module_class
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info))
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return train_loss_func, unlabel_train_loss_func, eval_loss_func
...@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics): ...@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics):
metric.reset() metric.reset()
def build_metrics(engine): def build_metrics(config):
config, mode = engine.config, engine.mode metrics_list = CombinedMetrics(copy.deepcopy(config))
if mode == 'train' and "Metric" in config and "Train" in config[ return metrics_list
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if hasattr(engine.train_dataloader, "collate_fn"
) and engine.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
train_metric_func = None
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
eval_mode = config["Global"].get("eval_mode", "classification")
if eval_mode == "classification":
if "Metric" in config and "Eval" in config["Metric"]:
eval_metric_func = CombinedMetrics(
copy.deepcopy(config["Metric"]["Eval"]))
else:
eval_metric_func = None
elif eval_mode == "retrieval":
if "Metric" in config and "Eval" in config["Metric"]:
metric_config = config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
eval_metric_func = None
return train_metric_func, eval_metric_func
...@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): ...@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph # model_list is None in static graph
def build_optimizer(config, dataloader, model_list=None): def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_config = copy.deepcopy(config["Optimizer"]) optim_config = copy.deepcopy(config)
epochs = config["Global"]["epochs"]
update_freq = config["Global"].get("update_freq", 1)
step_each_epoch = dataloader.iter_per_epoch // update_freq
if isinstance(optim_config, dict): if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}] # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name") optim_name = optim_config.pop("name")
......
...@@ -22,15 +22,16 @@ import paddle.distributed as dist ...@@ -22,15 +22,16 @@ import paddle.distributed as dist
_logger = None _logger = None
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO): def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name. """Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be be directly returned. During initialization, a StreamHandler will always be
added. added. If `log_file` is specified a FileHandler will also be added.
Args: Args:
config(dict): Training config.
name (str): Logger name. name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time. "Error" thus be silent most of the time.
...@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO): ...@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
if init_flag: if init_flag:
_logger.addHandler(stream_handler) _logger.addHandler(stream_handler)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], f"{mode}.log")
if log_file is not None and dist.get_rank() == 0: if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0] log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True) os.makedirs(log_file_folder, exist_ok=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册