提交 5d06a88a 编写于 作者: T Tingquan Gao

Revert "refactor: simplify engine"

This reverts commit 376d83d4.
上级 6aabb94d
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
import inspect import inspect
import copy import copy
import random import random
import platform
import paddle import paddle
import numpy as np import numpy as np
import paddle.distributed as dist import paddle.distributed as dist
...@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): ...@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build(config, mode, device, use_dali=False, seed=None): def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in [ assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
...@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None): ...@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None):
logger.debug("build data_loader({}) success...".format(data_loader)) logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader return data_loader
def build_dataloader(engine):
if "class_num" in engine.config["Global"]:
global_class_num = engine.config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
use_dali = engine.config['Global'].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
"Eval": None,
"Query": None,
"Gallery": None,
"GalleryQuery": None
}
if engine.mode == 'train':
train_dataloader = build(
engine.config["DataLoader"],
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
engine.iter_per_epoch = iter_per_epoch
train_dataloader.iter_per_epoch = iter_per_epoch
dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"],
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.eval_mode in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"],
"Eval",
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], key, engine.device,
use_dali)
else:
dataloader_dict["Gallery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Gallery",
engine.device, use_dali)
dataloader_dict["Query"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict
...@@ -15,6 +15,7 @@ from __future__ import division ...@@ -15,6 +15,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import platform
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from visualdl import LogWriter from visualdl import LogWriter
...@@ -51,60 +52,168 @@ class Engine(object): ...@@ -51,60 +52,168 @@ class Engine(object):
assert mode in ["train", "eval", "infer", "export"] assert mode in ["train", "eval", "infer", "export"]
self.mode = mode self.mode = mode
self.config = config self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
# set seed "classification")
self._init_seed() self.train_mode = self.config["Global"].get("train_mode", None)
# init logger
init_logger(self.config, mode=mode)
print_config(config)
# for visualdl
self.vdl_writer = self._init_vdl()
# is_rec
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec", if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
False): False):
self.is_rec = True self.is_rec = True
else: else:
self.is_rec = False self.is_rec = False
# set seed
seed = self.config["Global"].get("seed", False)
if seed or seed == 0:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
f"{mode}.log")
init_logger(log_file=log_file)
print_config(config)
# init train_func and eval_func # init train_func and eval_func
self.train_mode = self.config["Global"].get("train_mode", None) assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
if self.train_mode is None: if self.train_mode is None:
self.train_epoch_func = train_method.train_epoch self.train_epoch_func = train_method.train_epoch
else: else:
self.train_epoch_func = getattr(train_method, self.train_epoch_func = getattr(train_method,
"train_epoch_" + self.train_mode) "train_epoch_" + self.train_mode)
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.eval_func = getattr(evaluation, self.eval_mode + "_eval") self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
self.use_dali = self.config['Global'].get("use_dali", False)
# for visualdl
self.vdl_writer = None
if self.config['Global'][
'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
# set device # set device
self.device = self._init_device() assert self.config["Global"][
"device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device))
# gradient accumulation # gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1) self.update_freq = self.config["Global"].get("update_freq", 1)
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
#TODO(gaotingquan): support rec
class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
self.config["DataLoader"].update({
"epochs": self.config["Global"]["epochs"]
})
# build dataloader # build dataloader
self.dataloader_dict = build_dataloader(self) if self.mode == 'train':
self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[ self.train_dataloader = build_dataloader(
"Train"], self.dataloader_dict[ self.config["DataLoader"], "Train", self.device, self.use_dali)
"UnLabelTrain"], self.dataloader_dict["Eval"] if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[ self.unlabel_train_dataloader = build_dataloader(
"GalleryQuery"], self.dataloader_dict[ self.config["DataLoader"], "UnLabelTrain", self.device,
"Gallery"], self.dataloader_dict["Query"] self.use_dali)
else:
self.unlabel_train_dataloader = None
self.iter_per_epoch = len(
self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader)
if self.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
self.iter_per_epoch = self.config["Global"].get(
"iter_per_epoch")
self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode in ["classification", "adaface"]:
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device,
self.use_dali)
elif self.eval_mode == "retrieval":
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali)
# build loss # build loss
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss( if self.mode == "train":
self.config, self.mode) label_loss_info = self.config["Loss"]["Train"]
self.train_loss_func = build_loss(label_loss_info)
unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
None)
self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
loss_config = self.config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
self.eval_loss_func = build_loss(loss_config)
else:
self.eval_loss_func = None
else:
self.eval_loss_func = None
# build metric # build metric
self.train_metric_func, self.eval_metric_func = build_metrics(self) if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
"Metric"] and self.config["Metric"]["Train"]:
metric_config = self.config["Metric"]["Train"]
if hasattr(self.train_dataloader, "collate_fn"
) and self.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
self.train_metric_func = build_metrics(metric_config)
else:
self.train_metric_func = None
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode == "classification":
if "Metric" in self.config and "Eval" in self.config["Metric"]:
self.eval_metric_func = build_metrics(self.config["Metric"]
["Eval"])
else:
self.eval_metric_func = None
elif self.eval_mode == "retrieval":
if "Metric" in self.config and "Eval" in self.config["Metric"]:
metric_config = self.config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
self.eval_metric_func = build_metrics(metric_config)
else:
self.eval_metric_func = None
# build model # build model
self.model = build_model(self.config, self.mode) self.model = build_model(self.config, self.mode)
...@@ -112,18 +221,139 @@ class Engine(object): ...@@ -112,18 +221,139 @@ class Engine(object):
apply_to_static(self.config, self.model) apply_to_static(self.config, self.model)
# load_pretrain # load_pretrain
self._init_pretrained() if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
# build optimizer # build optimizer
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer( self.optimizer, self.lr_sch = build_optimizer(
self.config, self.train_dataloader, self.config["Optimizer"], self.config["Global"]["epochs"],
self.iter_per_epoch // self.update_freq,
[self.model, self.train_loss_func]) [self.model, self.train_loss_func])
# AMP training and evaluating # AMP training and evaluating
self._init_amp() self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
# for amp
if self.amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
self.amp_level = self.config['AMP'].get("level", "O1")
if self.amp_level not in ["O1", "O2"]:
msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
logger.warning(msg)
self.config['AMP']["level"] = "O1"
self.amp_level = "O1"
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
logger.warning(msg)
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.amp.decorate(
models=self.train_loss_func,
level=self.amp_level,
save_dtype='float32')
# build EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
# check the gpu num
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
if self.mode == "train":
std_gpu_num = 8 if isinstance(
self.config["Optimizer"],
dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
if world_size != std_gpu_num:
msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
logger.warning(msg)
# for distributed # for distributed
self._init_dist() if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
# set different seed in different GPU manually in distributed environment
if seed is None:
logger.warning(
"The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
)
self.config["Global"]["seed"] = seed = 42
logger.info(
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
)
paddle.seed(int(seed) + dist.get_rank())
np.random.seed(int(seed) + dist.get_rank())
random.seed(int(seed) + dist.get_rank())
# build postprocess for infer
if self.mode == 'infer':
self.preprocess_func = create_operators(self.config["Infer"][
"transforms"])
self.postprocess_func = build_postprocess(self.config["Infer"][
"PostProcess"])
def train(self): def train(self):
assert self.mode == "train" assert self.mode == "train"
...@@ -133,17 +363,10 @@ class Engine(object): ...@@ -133,17 +363,10 @@ class Engine(object):
"metric": -1.0, "metric": -1.0,
"epoch": 0, "epoch": 0,
} }
ema_module = None
# build EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if self.ema: if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
best_metric_ema = 0.0 best_metric_ema = 0.0
ema_module = self.model_ema.module ema_module = self.model_ema.module
else:
ema_module = None
# key: # key:
# val: metrics list word # val: metrics list word
self.output_info = dict() self.output_info = dict()
...@@ -169,6 +392,8 @@ class Engine(object): ...@@ -169,6 +392,8 @@ class Engine(object):
# for one epoch train # for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step) self.train_epoch_func(self, epoch_id, print_batch_step)
if self.use_dali:
self.train_dataloader.reset()
metric_msg = ", ".join( metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info]) [self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format( logger.info("[Train][Epoch {}/{}][Avg]{}".format(
...@@ -274,12 +499,6 @@ class Engine(object): ...@@ -274,12 +499,6 @@ class Engine(object):
@paddle.no_grad() @paddle.no_grad()
def infer(self): def infer(self):
assert self.mode == "infer" and self.eval_mode == "classification" assert self.mode == "infer" and self.eval_mode == "classification"
self.preprocess_func = create_operators(self.config["Infer"][
"transforms"])
self.postprocess_func = build_postprocess(self.config["Infer"][
"PostProcess"])
total_trainer = dist.get_world_size() total_trainer = dist.get_world_size()
local_rank = dist.get_rank() local_rank = dist.get_rank()
image_list = get_image_list(self.config["Infer"]["infer_imgs"]) image_list = get_image_list(self.config["Infer"]["infer_imgs"])
...@@ -367,148 +586,6 @@ class Engine(object): ...@@ -367,148 +586,6 @@ class Engine(object):
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"." f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
) )
def _init_vdl(self):
if self.config['Global'][
'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
return LogWriter(logdir=vdl_writer_path)
return None
def _init_seed(self):
seed = self.config["Global"].get("seed", False)
if dist.get_world_size() != 1:
# if self.config["Global"]["distributed"]:
# set different seed in different GPU manually in distributed environment
if not seed:
logger.warning(
"The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
)
self.config["Global"]["seed"] = seed = 42
logger.info(
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
)
dist_seed = int(seed) + dist.get_rank()
paddle.seed(dist_seed)
np.random.seed(dist_seed)
random.seed(dist_seed)
elif seed or seed == 0:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
def _init_device(self):
device = self.config["Global"]["device"]
assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, device))
return paddle.set_device(device)
def _init_pretrained(self):
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
def _init_amp(self):
self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
# for amp
if self.amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
self.amp_level = self.config['AMP'].get("level", "O1")
if self.amp_level not in ["O1", "O2"]:
msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
logger.warning(msg)
self.config['AMP']["level"] = "O1"
self.amp_level = "O1"
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
logger.warning(msg)
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.amp.decorate(
models=self.train_loss_func,
level=self.amp_level,
save_dtype='float32')
def _init_dist(self):
# check the gpu num
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
# TODO(gaotingquan):
if self.mode == "train":
std_gpu_num = 8 if isinstance(
self.config["Optimizer"],
dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
if world_size != std_gpu_num:
msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
logger.warning(msg)
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
class ExportModel(TheseusLayer): class ExportModel(TheseusLayer):
""" """
......
...@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss ...@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
super().__init__() super().__init__()
loss_func = [] self.loss_func = []
self.loss_weight = [] self.loss_weight = []
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
...@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer): ...@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer):
assert "weight" in param, "weight must be in param, but param just contains {}".format( assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys()) param.keys())
self.loss_weight.append(param.pop("weight")) self.loss_weight.append(param.pop("weight"))
loss_func.append(eval(name)(**param)) self.loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(loss_func) self.loss_func = nn.LayerList(self.loss_func)
logger.debug("build loss {} success.".format(loss_func))
def __call__(self, input, batch): def __call__(self, input, batch):
loss_dict = {} loss_dict = {}
...@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer): ...@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer):
return loss_dict return loss_dict
def build_loss(config, mode="train"): def build_loss(config):
train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None if config is None:
if mode == "train": return None
label_loss_info = config["Loss"]["Train"] module_class = CombinedLoss(copy.deepcopy(config))
if label_loss_info: logger.debug("build loss {} success.".format(module_class))
train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info)) return module_class
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info))
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return train_loss_func, unlabel_train_loss_func, eval_loss_func
...@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics): ...@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics):
metric.reset() metric.reset()
def build_metrics(engine): def build_metrics(config):
config, mode = engine.config, engine.mode metrics_list = CombinedMetrics(copy.deepcopy(config))
if mode == 'train' and "Metric" in config and "Train" in config[ return metrics_list
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if hasattr(engine.train_dataloader, "collate_fn"
) and engine.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
train_metric_func = None
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
eval_mode = config["Global"].get("eval_mode", "classification")
if eval_mode == "classification":
if "Metric" in config and "Eval" in config["Metric"]:
eval_metric_func = CombinedMetrics(
copy.deepcopy(config["Metric"]["Eval"]))
else:
eval_metric_func = None
elif eval_mode == "retrieval":
if "Metric" in config and "Eval" in config["Metric"]:
metric_config = config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
eval_metric_func = None
return train_metric_func, eval_metric_func
...@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): ...@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph # model_list is None in static graph
def build_optimizer(config, dataloader, model_list=None): def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_config = copy.deepcopy(config["Optimizer"]) optim_config = copy.deepcopy(config)
epochs = config["Global"]["epochs"]
update_freq = config["Global"].get("update_freq", 1)
step_each_epoch = dataloader.iter_per_epoch // update_freq
if isinstance(optim_config, dict): if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}] # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name") optim_name = optim_config.pop("name")
......
...@@ -22,15 +22,16 @@ import paddle.distributed as dist ...@@ -22,15 +22,16 @@ import paddle.distributed as dist
_logger = None _logger = None
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO): def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name. """Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be be directly returned. During initialization, a StreamHandler will always be
added. added. If `log_file` is specified a FileHandler will also be added.
Args: Args:
config(dict): Training config.
name (str): Logger name. name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time. "Error" thus be silent most of the time.
...@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO): ...@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
if init_flag: if init_flag:
_logger.addHandler(stream_handler) _logger.addHandler(stream_handler)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], f"{mode}.log")
if log_file is not None and dist.get_rank() == 0: if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0] log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True) os.makedirs(log_file_folder, exist_ok=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册