Commit 339be96e authored by Tingquan Gao

Revert "refactor"

This reverts commit 187f38eb.
Parent 915dde17
@@ -110,7 +110,7 @@ def build(config, mode, use_dali=False, seed=None):
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
     if 'batch_transform_ops' in config_dataset:
-        batch_transform = config_dataset['batch_transform_ops']
+        batch_transform = config_dataset.pop('batch_transform_ops')
     else:
         batch_transform = None
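An illustrative note on the hunk above: `pop` removes the key, so the rest of `config_dataset` can presumably be expanded into the dataset constructor afterwards, while plain subscription leaves the key behind. A minimal, self-contained sketch (hypothetical keys and constructor, not PaddleClas code):

import copy

# Hypothetical dataset config; only "batch_transform_ops" matters here.
config_dataset = {
    "image_root": "./dataset/",
    "cls_label_path": "train_list.txt",
    "batch_transform_ops": [{"MixupOperator": {"alpha": 0.2}}],
}
config_dataset = copy.deepcopy(config_dataset)

def build_dataset(image_root, cls_label_path):
    # Stand-in constructor that only accepts the arguments it knows about.
    return (image_root, cls_label_path)

batch_transform = config_dataset.pop("batch_transform_ops")  # key removed
dataset = build_dataset(**config_dataset)                    # OK

# With config_dataset["batch_transform_ops"] instead of .pop(), the same call
# would raise: TypeError: ... unexpected keyword argument 'batch_transform_ops'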
@@ -254,11 +254,10 @@ def build_dataloader(config, mode):
     if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
-        task = config["Global"].get("task", "classification")
-        if task in ["classification", "adaface"]:
+        if config["Global"]["eval_mode"] in ["classification", "adaface"]:
             dataloader_dict["Eval"] = build(
                 config["DataLoader"], "Eval", use_dali, seed=None)
-        elif task == "retrieval":
+        elif config["Global"]["eval_mode"] == "retrieval":
             if len(config["DataLoader"]["Eval"].keys()) == 1:
                 key = list(config["DataLoader"]["Eval"].keys())[0]
                 dataloader_dict["GalleryQuery"] = build(
......
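For context on the branch above, a hedged sketch of the `Global` keys it reads after the revert (values are assumptions for illustration; the refactor being reverted had read a `task` key with a `"classification"` default instead):

# Hypothetical Global section of a PaddleClas-style config, as a Python dict.
global_cfg = {
    "eval_during_train": True,
    "eval_mode": "retrieval",  # or "classification" / "adaface"
}

# "classification" / "adaface" build a single "Eval" dataloader;
# "retrieval" builds "Gallery"/"Query" (or one "GalleryQuery") dataloaders.
assert global_cfg["eval_mode"] in ("classification", "adaface", "retrieval")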
@@ -42,7 +42,7 @@ from ppcls.data.preprocess.ops.dali_operators import RandomRot90
 from ppcls.data.preprocess.ops.dali_operators import RandomRotation
 from ppcls.data.preprocess.ops.dali_operators import ResizeImage
 from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
-from ppcls.utils import type_name
+from ppcls.engine.train.utils import type_name
 from ppcls.utils import logger
 
 INTERP_MAP = {
......
@@ -14,7 +14,8 @@
 from __future__ import absolute_import, division, print_function
 
 from ppcls.data import build_dataloader
-from ppcls.utils import logger, type_name
+from ppcls.engine.train.utils import type_name
+from ppcls.utils import logger
 from .regular_train_epoch import regular_train_epoch
......
@@ -14,7 +14,7 @@
 from __future__ import absolute_import, division, print_function
 
 import datetime
-from ppcls.utils import logger, type_name
+from ppcls.utils import logger
 from ppcls.utils.misc import AverageMeter
@@ -75,3 +75,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
                 value=trainer.output_info[key].avg,
                 step=trainer.global_step,
                 writer=trainer.vdl_writer)
+
+
+def type_name(object: object) -> str:
+    """get class name of an object"""
+    return object.__class__.__name__
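The `type_name` helper moved back into `ppcls/engine/train/utils.py` above is trivial to demonstrate; a quick standalone sketch (the stub class below exists only for this example):

def type_name(object: object) -> str:
    """get class name of an object"""
    return object.__class__.__name__

class AverageMeter:  # stub standing in for ppcls.utils.misc.AverageMeter
    pass

print(type_name(AverageMeter()))  # AverageMeter
print(type_name("abc"))           # str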
@@ -94,7 +94,6 @@ def build_loss(config, mode="train"):
         if unlabel_loss_info:
             unlabel_train_loss_func = CombinedLoss(
                 copy.deepcopy(unlabel_loss_info))
-        return train_loss_func, unlabel_train_loss_func
     if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
         loss_config = config.get("Loss", None)
@@ -102,4 +101,5 @@ def build_loss(config, mode="train"):
             loss_config = loss_config.get("Eval")
             if loss_config is not None:
                 eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
-    return eval_loss_func
+
+    return train_loss_func, unlabel_train_loss_func, eval_loss_func
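The two loss hunks above replace per-branch early returns with a single three-element return. A minimal stand-in (not the real `build_loss`) showing the pattern the revert restores, where anything not built stays `None`:

def build_loss_sketch(mode="train", eval_during_train=True):
    # Every branch fills its own slot; unused slots are returned as None.
    train_loss_func = None
    unlabel_train_loss_func = None
    eval_loss_func = None
    if mode == "train":
        train_loss_func = lambda out, batch: {"CELoss": 0.0}
    if mode == "eval" or (mode == "train" and eval_during_train):
        eval_loss_func = lambda out, batch: {"CELoss": 0.0}
    return train_loss_func, unlabel_train_loss_func, eval_loss_func

# Callers always unpack all three values at once.
train_fn, unlabel_fn, eval_fn = build_loss_sketch(mode="train")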
@@ -65,19 +65,22 @@ class CombinedMetrics(AvgMetrics):
             metric.reset()
 
 
-def build_metrics(config, mode):
+def build_metrics(engine):
+    config, mode = engine.config, engine.mode
     if mode == 'train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
-        if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
-                                                        None):
+        if hasattr(engine.dataloader_dict["Train"],
+                   "collate_fn") and engine.dataloader_dict[
+                       "Train"].collate_fn is not None:
             for m_idx, m in enumerate(metric_config):
                 if "TopkAcc" in m:
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                     logger.warning(msg)
                     metric_config.pop(m_idx)
         train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-        return train_metric_func
+    else:
+        train_metric_func = None
 
     if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
@@ -94,4 +97,7 @@ def build_metrics(config, mode):
             else:
                 metric_config = [{"name": "Recallk", "topk": (1, 5)}]
         eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-        return eval_metric_func
+    else:
+        eval_metric_func = None
+
+    return train_metric_func, eval_metric_func
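A hedged sketch of the check the metrics hunk restores: instead of re-reading `batch_transform_ops` from the config, it inspects the train dataloader itself and drops `TopkAcc` when a batch-level `collate_fn` (e.g. mixup/cutmix) is present. Stand-in objects only, not PaddleClas classes:

class FakeTrainLoader:
    def __init__(self, collate_fn=None):
        self.collate_fn = collate_fn

metric_config = [{"TopkAcc": {"topk": [1, 5]}}, {"ATTRMetric": {}}]
train_loader = FakeTrainLoader(collate_fn=lambda batch: batch)

if getattr(train_loader, "collate_fn", None) is not None:
    # Top-k accuracy is ill-defined on mixed labels, so TopkAcc entries go.
    metric_config = [m for m in metric_config if "TopkAcc" not in m]

print(metric_config)  # [{'ATTRMetric': {}}]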
@@ -21,7 +21,8 @@ import copy
 import paddle
 from typing import Dict, List
 
-from ..utils import logger, type_name
+from ppcls.engine.train.utils import type_name
+from ppcls.utils import logger
 from . import optimizer
@@ -44,10 +45,14 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 
 # model_list is None in static graph
-def build_optimizer(config, max_iter, model_list, update_freq):
+def build_optimizer(engine):
+    if engine.mode != "train":
+        return None, None
+    config, max_iter, model_list = engine.config, engine.dataloader_dict[
+        "Train"].max_iter, [engine.model, engine.train_loss_func]
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
-    update_freq = config["Global"].get("update_freq", 1)
+    update_freq = engine.update_freq
     step_each_epoch = max_iter // update_freq
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
......
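A small arithmetic sketch of the values the reverted `build_optimizer` derives from the engine (the numbers are assumptions): `max_iter` comes from the train dataloader and is divided by the gradient-accumulation frequency before the LR scheduler sees it.

# Assumed values for illustration only.
max_iter = 1250      # iterations per epoch reported by the train dataloader
update_freq = 2      # accumulate gradients over 2 iterations per optimizer step

step_each_epoch = max_iter // update_freq
print(step_each_epoch)  # 625 scheduler steps per epoch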
@@ -26,8 +26,3 @@ from .metrics import multi_hot_encode
 from .metrics import precision_recall_fscore
 from .misc import AverageMeter
 from .save_load import init_model
-
-
-def type_name(object: object) -> str:
-    """get class name of an object"""
-    return object.__class__.__name__