提交 187f38eb 编写于 作者: G gaotingquan 提交者: Wei Shengyu

refactor

1. rm Global.eval_mode
2. add Global.task
3. mv type_name to ppcls.utils
4. build dataloader, loss, metric by mode
上级 97935164
......@@ -110,7 +110,7 @@ def build(config, mode, use_dali=False, seed=None):
config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name')
if 'batch_transform_ops' in config_dataset:
batch_transform = config_dataset.pop('batch_transform_ops')
batch_transform = config_dataset['batch_transform_ops']
else:
batch_transform = None
......@@ -254,10 +254,11 @@ def build_dataloader(config, mode):
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
if config["Global"]["eval_mode"] in ["classification", "adaface"]:
task = config["Global"].get("task", "classification")
if task in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
config["DataLoader"], "Eval", use_dali, seed=None)
elif config["Global"]["eval_mode"] == "retrieval":
elif task == "retrieval":
if len(config["DataLoader"]["Eval"].keys()) == 1:
key = list(config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build(
......
......@@ -42,7 +42,7 @@ from ppcls.data.preprocess.ops.dali_operators import RandomRot90
from ppcls.data.preprocess.ops.dali_operators import RandomRotation
from ppcls.data.preprocess.ops.dali_operators import ResizeImage
from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
from ppcls.engine.train.utils import type_name
from ppcls.utils import type_name
from ppcls.utils import logger
INTERP_MAP = {
......
......@@ -14,8 +14,7 @@
from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
from ppcls.utils import logger, type_name
from .regular_train_epoch import regular_train_epoch
......
......@@ -14,7 +14,7 @@
from __future__ import absolute_import, division, print_function
import datetime
from ppcls.utils import logger
from ppcls.utils import logger, type_name
from ppcls.utils.misc import AverageMeter
......@@ -75,8 +75,3 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
value=trainer.output_info[key].avg,
step=trainer.global_step,
writer=trainer.vdl_writer)
def type_name(object: object) -> str:
    """Return the class name of *object* (e.g. ``type_name([]) == 'list'``)."""
    klass = object.__class__
    return klass.__name__
......@@ -94,6 +94,7 @@ def build_loss(config, mode="train"):
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info))
return train_loss_func, unlabel_train_loss_func
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
......@@ -101,5 +102,4 @@ def build_loss(config, mode="train"):
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return train_loss_func, unlabel_train_loss_func, eval_loss_func
return eval_loss_func
......@@ -65,22 +65,19 @@ class CombinedMetrics(AvgMetrics):
metric.reset()
def build_metrics(engine):
config, mode = engine.config, engine.mode
def build_metrics(config, mode):
if mode == 'train' and "Metric" in config and "Train" in config[
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if hasattr(engine.dataloader_dict["Train"],
"collate_fn") and engine.dataloader_dict[
"Train"].collate_fn is not None:
if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
None):
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
train_metric_func = None
return train_metric_func
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
......@@ -97,7 +94,4 @@ def build_metrics(engine):
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
eval_metric_func = None
return train_metric_func, eval_metric_func
return eval_metric_func
......@@ -21,8 +21,7 @@ import copy
import paddle
from typing import Dict, List
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
from ..utils import logger, type_name
from . import optimizer
......@@ -45,14 +44,10 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(engine):
if engine.mode != "train":
return None, None
config, max_iter, model_list = engine.config, engine.dataloader_dict[
"Train"].max_iter, [engine.model, engine.train_loss_func]
def build_optimizer(config, max_iter, model_list, update_freq):
optim_config = copy.deepcopy(config["Optimizer"])
epochs = config["Global"]["epochs"]
update_freq = engine.update_freq
update_freq = config["Global"].get("update_freq", 1)
step_each_epoch = max_iter // update_freq
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
......
......@@ -26,3 +26,8 @@ from .metrics import multi_hot_encode
from .metrics import precision_recall_fscore
from .misc import AverageMeter
from .save_load import init_model
def type_name(object: object) -> str:
    """Return the class name of *object* (e.g. ``type_name([]) == 'list'``)."""
    klass = object.__class__
    return klass.__name__
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册