Commit 32593b63 authored by gaotingquan, committed by Wei Shengyu

refactor

Parent d3941dc1
@@ -88,25 +88,32 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed)
def build(config, mode, use_dali=False, seed=None):
def build_dataloader(config, mode, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert mode in config.keys(), "{} config not in yaml".format(mode)
assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
mode)
dataloader_config = config["DataLoader"][mode]
class_num = config["Arch"].get("class_num", None)
epochs = config["Global"]["epochs"]
use_dali = config["Global"].get("use_dali", False)
num_workers = dataloader_config['loader']["num_workers"]
use_shared_memory = dataloader_config['loader']["use_shared_memory"]
# build dataset
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(
config,
config["DataLoader"],
mode,
paddle.device.get_device(),
num_threads=config[mode]['loader']["num_workers"],
num_threads=num_workers,
seed=seed,
enable_fuse=True)
class_num = config.get("class_num", None)
epochs = config.get("epochs", None)
config_dataset = config[mode]['dataset']
config_dataset = dataloader_config['dataset']
config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name')
if 'batch_transform_ops' in config_dataset:
@@ -119,7 +126,7 @@ def build(config, mode, use_dali=False, seed=None):
logger.debug("build dataset({}) success...".format(dataset))
# build sampler
config_sampler = config[mode]['sampler']
config_sampler = dataloader_config['sampler']
if config_sampler and "name" not in config_sampler:
batch_sampler = None
batch_size = config_sampler["batch_size"]
@@ -153,11 +160,6 @@ def build(config, mode, use_dali=False, seed=None):
else:
batch_collate_fn = None
# build dataloader
config_loader = config[mode]['loader']
num_workers = config_loader["num_workers"]
use_shared_memory = config_loader["use_shared_memory"]
init_fn = partial(
worker_init_fn,
num_workers=num_workers,
@@ -194,78 +196,36 @@ def build(config, mode, use_dali=False, seed=None):
data_loader.max_iter = max_iter
data_loader.total_samples = total_samples
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
# TODO(gaotingquan): perf
class DataIterator(object):
def __init__(self, dataloader, use_dali=False):
self.dataloader = dataloader
self.use_dali = use_dali
self.iterator = iter(dataloader)
self.max_iter = dataloader.max_iter
self.total_samples = dataloader.total_samples
def get_batch(self):
# fetch data batch from dataloader
try:
batch = next(self.iterator)
except Exception:
# NOTE: reset DALI dataloader manually
if self.use_dali:
self.dataloader.reset()
self.iterator = iter(self.dataloader)
batch = next(self.iterator)
return batch
def build_dataloader(config, mode):
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
use_dali = config["Global"].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
"Eval": None,
"Query": None,
"Gallery": None,
"GalleryQuery": None
}
if mode == 'train':
train_dataloader = build(
config["DataLoader"], "Train", use_dali, seed=None)
if config["DataLoader"]["Train"].get("max_iter", None):
# TODO(gaotingquan): mv to build_sampler
if mode == "train":
if dataloader_config["Train"].get("max_iter", None):
# set max iteration per epoch manually when training by iteration(s), such as XBM, FixMatch.
max_iter = config["Train"].get("max_iter")
update_freq = config["Global"].get("update_freq", 1)
max_iter = train_dataloader.max_iter // update_freq * update_freq
train_dataloader.max_iter = max_iter
if config["DataLoader"]["Train"].get("convert_iterator", True):
train_dataloader = DataIterator(train_dataloader, use_dali)
dataloader_dict["Train"] = train_dataloader
max_iter = data_loader.max_iter // update_freq * update_freq
data_loader.max_iter = max_iter
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
if config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
task = config["Global"].get("task", "classification")
if task in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
config["DataLoader"], "Eval", use_dali, seed=None)
elif task == "retrieval":
if len(config["DataLoader"]["Eval"].keys()) == 1:
key = list(config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build(
config["DataLoader"]["Eval"], key, use_dali)
else:
dataloader_dict["Gallery"] = build(
config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
"Query", use_dali)
return dataloader_dict
# # TODO(gaotingquan): the length of dataloader should be determined by sampler
# class DataIterator(object):
# def __init__(self, dataloader, use_dali=False):
# self.dataloader = dataloader
# self.use_dali = use_dali
# self.iterator = iter(dataloader)
# self.max_iter = dataloader.max_iter
# self.total_samples = dataloader.total_samples
# def get_batch(self):
# # fetch data batch from dataloader
# try:
# batch = next(self.iterator)
# except Exception:
# # NOTE: reset DALI dataloader manually
# if self.use_dali:
# self.dataloader.reset()
# self.iterator = iter(self.dataloader)
# batch = next(self.iterator)
# return batch
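
Summary of the data-building refactor above: the old build()/build_dataloader() pair, which returned a dict of loaders keyed by mode, is replaced by a single build_dataloader(config, mode, seed=None) that reads config["DataLoader"][mode] and returns one loader with max_iter and total_samples attached. A minimal usage sketch follows, assuming PaddleClas with this commit is importable; the get_config helper and the YAML path are assumptions, not part of this diff.

from ppcls.data import build_dataloader
from ppcls.utils import config as conf_util  # assumed config-parsing helper

# hypothetical config file; any PaddleClas YAML with DataLoader/Arch/Global sections should do
cfg = conf_util.get_config("ppcls/configs/ImageNet/ResNet/ResNet50.yaml")
train_loader = build_dataloader(cfg, "Train")          # one loader per call, no dataloader dict
eval_loader = build_dataloader(cfg, "Eval", seed=None)
print(train_loader.max_iter, train_loader.total_samples)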
@@ -60,6 +60,8 @@ class Engine(object):
# load_pretrain
self._init_pretrained()
self._init_amp()
# init train_func and eval_func
self.eval = build_eval_func(
self.config, mode=self.mode, model=self.model)
@@ -69,7 +71,7 @@ class Engine(object):
# for distributed
self._init_dist()
print_config(config)
print_config(self.config)
@paddle.no_grad()
def infer(self):
......
@@ -29,10 +29,11 @@ class ClassEval(object):
def __init__(self, config, mode, model):
self.config = config
self.model = model
self.print_batch_step = self.config["Global"]["print_batch_step"]
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(config, "eval")
self.eval_dataloader = build_dataloader(config, "eval")
self.eval_loss_func = build_loss(config, "eval")
self.eval_metric_func = build_metrics(self.config, "eval")
self.eval_dataloader = build_dataloader(self.config, "Eval")
self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
@paddle.no_grad()
@@ -48,13 +49,12 @@ class ClassEval(object):
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
print_batch_step = self.config["Global"]["print_batch_step"]
tic = time.time()
total_samples = self.eval_dataloader["Eval"].total_samples
total_samples = self.eval_dataloader.total_samples
accum_samples = 0
max_iter = self.eval_dataloader["Eval"].max_iter
for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
max_iter = self.eval_dataloader.max_iter
for iter_id, batch in enumerate(self.eval_dataloader):
if iter_id >= max_iter:
break
if iter_id == 5:
@@ -130,7 +130,7 @@ class ClassEval(object):
self.eval_metric_func(preds, labels)
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
if iter_id % self.print_batch_step == 0:
time_msg = "s, ".join([
"{}: {:.5f}".format(key, time_info[key].avg)
for key in time_info
@@ -153,7 +153,7 @@ class ClassEval(object):
tic = time.time()
if self.use_dali:
self.eval_dataloader["Eval"].reset()
self.eval_dataloader.reset()
if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([
......
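
The evaluator hunks above drop the dataloader dict: ClassEval now builds and iterates a single Eval loader held in self.eval_dataloader. A sketch of the consuming loop pattern, using only attributes shown in this diff (max_iter, total_samples) and assuming config is a parsed PaddleClas config dict and that batch[0]/batch[1] are images/labels as in the trainer below:

eval_loader = build_dataloader(config, "Eval")
accum_samples = 0
for iter_id, batch in enumerate(eval_loader):
    if iter_id >= eval_loader.max_iter:      # loader length is capped by max_iter
        break
    images, labels = batch[0], batch[1]
    accum_samples += labels.shape[0]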
@@ -25,7 +25,7 @@ def build_train_func(config, mode, model, eval_func):
train_mode = config["Global"].get("task", None)
if train_mode is None:
config["Global"]["task"] = "classification"
return ClassTrainer(config, mode, model, eval_func)
return ClassTrainer(config, model, eval_func)
else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
config, mode, model, eval_func)
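
Per the hunk above, Global.task defaults to "classification" and ClassTrainer no longer takes the mode argument. A hedged wiring sketch; build_eval_func's signature is taken from the Engine hunk earlier, and config/model are assumed to exist:

eval_func = build_eval_func(config, mode="train", model=model)
train_func = build_train_func(config, mode="train", model=model, eval_func=eval_func)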
@@ -28,7 +28,7 @@ from ...utils.save_load import init_model, ModelSaver
class ClassTrainer(object):
def __init__(self, config, mode, model, eval_func):
def __init__(self, config, model, eval_func):
self.config = config
self.model = model
self.eval = eval_func
@@ -41,32 +41,32 @@ class ClassTrainer(object):
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
# AMP training and evaluating
# self._init_amp()
# TODO(gaotingquan): mv to build_model
# build EMA model
self.model_ema = self._build_ema_model()
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self.config, mode)
self.dataloader = build_dataloader(self.config, "Train")
# build loss
self.train_loss_func, self.unlabel_train_loss_func = build_loss(
self.config, mode)
self.loss_func = build_loss(config, "Train")
# build metric
self.train_metric_func = build_metrics(config, "train")
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.dataloader_dict["Train"].max_iter,
[self.model, self.train_loss_func], self.update_freq)
self.config, self.dataloader.max_iter,
[self.model, self.loss_func], self.update_freq)
# build model saver
self.model_saver = ModelSaver(
self,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema")
config=self.config,
net=self.model,
loss=self.loss_func,
opt=self.optimizer,
model_ema=self.model_ema if self.model_ema else None)
# build best metric
self.best_metric = {
@@ -84,8 +84,6 @@ class ClassTrainer(object):
"reader_cost", ".5f", postfix=" s,"),
}
# build EMA model
self.model_ema = self._build_ema_model()
self._init_checkpoints()
# for visualdl
@@ -173,11 +171,10 @@ class ClassTrainer(object):
}, prefix="latest")
def train_epoch(self, epoch_id):
self.model.train()
tic = time.time()
for iter_id in range(self.dataloader_dict["Train"].max_iter):
batch = self.dataloader_dict["Train"].get_batch()
for iter_id, batch in enumerate(self.dataloader):
profiler.add_profiler_step(self.config["profiler_options"])
if iter_id == 5:
for key in self.time_info:
@@ -190,7 +187,7 @@ class ClassTrainer(object):
self.global_step += 1
out = self.model(batch)
loss_dict = self.train_loss_func(out, batch[1])
loss_dict = self.loss_func(out, batch[1])
# TODO(gaotingquan): mv update_freq to loss and optimizer
loss = loss_dict["loss"] / self.update_freq
loss.backward()
......
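
The trainer hunks above replace the per-mode dataloader dict and the DataIterator wrapper with a plain for-loop over self.dataloader, and rename train_loss_func to loss_func. A sketch of the new loop shape, where model, dataloader, loss_func and update_freq stand for the trainer attributes built above; the elided optimizer/lr steps are outside the shown hunks:

model.train()
for iter_id, batch in enumerate(dataloader):       # direct iteration, no get_batch()
    out = model(batch)
    loss_dict = loss_func(out, batch[1])            # batch[1] holds the labels
    loss = loss_dict["loss"] / update_freq          # gradient-accumulation factor
    loss.backward()
    # optimizer.step() / lr scheduling happen outside the hunks shown here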
@@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
batch_size / trainer.time_info["batch_cost"].avg)
eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
) * trainer.dataloader_dict["Train"].max_iter - iter_id
) * len(trainer.dataloader) - iter_id
) * trainer.time_info["batch_cost"].avg
eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
epoch_id, trainer.config["Global"][
"epochs"], iter_id, trainer.dataloader_dict["Train"]
.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
epoch_id, trainer.config["Global"]["epochs"], iter_id,
len(trainer.dataloader), lr_msg, metric_msg, time_msg, ips_msg,
eta_msg))
for i, lr in enumerate(trainer.lr_sch):
logger.scaler(
......
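
The logging hunk switches the iteration count from trainer.dataloader_dict["Train"].max_iter to len(trainer.dataloader). The ETA arithmetic it computes, written out in one place (names as in the hunk above):

# remaining iterations across all epochs, times average seconds per batch
remaining = (trainer.config["Global"]["epochs"] - epoch_id + 1) * len(trainer.dataloader) - iter_id
eta_sec = remaining * trainer.time_info["batch_cost"].avg
eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))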
@@ -50,8 +50,9 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
def __init__(self, config_list, amp_config=None):
def __init__(self, config_list, mode, amp_config=None):
super().__init__()
self.mode = mode
loss_func = []
self.loss_weight = []
assert isinstance(config_list, list), (
@@ -68,11 +69,13 @@ class CombinedLoss(nn.Layer):
self.loss_func = nn.LayerList(loss_func)
logger.debug("build loss {} success.".format(loss_func))
self.scaler = None
if amp_config:
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=config["AMP"].get("scale_loss", 1.0),
use_dynamic_loss_scaling=config["AMP"].get(
"use_dynamic_loss_scaling", False))
if self.mode == "Train" or AMPForwardDecorator.amp_eval:
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=amp_config.get("scale_loss", 1.0),
use_dynamic_loss_scaling=amp_config.get(
"use_dynamic_loss_scaling", False))
@AMP_forward_decorator
def __call__(self, input, batch):
@@ -89,49 +92,26 @@ class CombinedLoss(nn.Layer):
loss = {key: loss[key] * weight for key in loss}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
# TODO(gaotingquan): if amp_eval & eval_loss ?
if AMPForwardDecorator.amp_level:
if self.scaler:
self.scaler(loss_dict["loss"])
return loss_dict
def build_loss(config, mode="train"):
if mode == "train":
label_loss_info = config["Loss"]["Train"]
if label_loss_info:
train_loss_func = CombinedLoss(
copy.deepcopy(label_loss_info), config.get("AMP", None))
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
else:
unlabel_train_loss_func = None
def build_loss(config, mode):
if config["Loss"][mode] is None:
return None
module_class = CombinedLoss(
copy.deepcopy(config["Loss"][mode]),
mode,
amp_config=config.get("AMP", None))
if AMPForwardDecorator.amp_level is not None:
train_loss_func = paddle.amp.decorate(
models=train_loss_func,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
# TODO(gaotingquan): unlabel_loss_info may be None
unlabel_train_loss_func = paddle.amp.decorate(
models=unlabel_train_loss_func,
if AMPForwardDecorator.amp_level is not None:
if mode == "Train" or AMPForwardDecorator.amp_eval:
module_class = paddle.amp.decorate(
models=module_class,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
return train_loss_func, unlabel_train_loss_func
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(
copy.deepcopy(loss_config), config.get("AMP", None))
if AMPForwardDecorator.amp_level is not None and AMPForwardDecorator.amp_eval:
eval_loss_func = paddle.amp.decorate(
models=eval_loss_func,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
return eval_loss_func
logger.debug("build loss {} success.".format(module_class))
return module_class
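
build_loss is unified above: one code path builds a CombinedLoss for the requested mode, passes the mode through so the GradScaler is only created for "Train" (or when amp_eval is set), and wraps the module with paddle.amp.decorate under AMP. A usage sketch, assuming a config with a Loss.Train section and an optional Loss.Eval section:

train_loss = build_loss(config, "Train")   # CombinedLoss, possibly AMP-decorated
eval_loss = build_loss(config, "Eval")     # returns None when config["Loss"]["Eval"] is null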
@@ -150,21 +150,16 @@ def _extract_student_weights(all_params, student_prefix="Student."):
class ModelSaver(object):
def __init__(self,
trainer,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema"):
def __init__(self, config, net, loss, opt, model_ema):
# net, loss, opt, model_ema, output_dir,
self.trainer = trainer
self.net_name = net_name
self.loss_name = loss_name
self.opt_name = opt_name
self.model_ema_name = model_ema_name
arch_name = trainer.config["Arch"]["name"]
self.output_dir = os.path.join(trainer.output_dir, arch_name)
self.net = net
self.loss = loss
self.opt = opt
self.model_ema = model_ema
arch_name = config["Arch"]["name"]
self.output_dir = os.path.join(config["Global"]["output_dir"],
arch_name)
_mkdir_if_not_exist(self.output_dir)
def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +169,8 @@ class ModelSaver(object):
save_dir = os.path.join(self.output_dir, prefix)
params_state_dict = getattr(self.trainer, self.net_name).state_dict()
loss = getattr(self.trainer, self.loss_name)
params_state_dict = self.net.state_dict()
loss = self.loss
if loss is not None:
loss_state_dict = loss.state_dict()
keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +185,11 @@ class ModelSaver(object):
paddle.save(s_params, save_dir + "_student.pdparams")
paddle.save(params_state_dict, save_dir + ".pdparams")
model_ema = getattr(self.trainer, self.model_ema_name)
model_ema = self.model_ema
if model_ema is not None:
paddle.save(model_ema.module.state_dict(),
save_dir + ".ema.pdparams")
optimizer = getattr(self.trainer, self.opt_name)
optimizer = self.opt
paddle.save([opt.state_dict() for opt in optimizer],
save_dir + ".pdopt")
paddle.save(metric_info, save_dir + ".pdstates")
......
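
ModelSaver above now receives explicit objects instead of pulling attributes off the trainer by name. A construction/usage sketch based only on the signature and calls shown in these hunks; the metric_info keys are an assumption:

saver = ModelSaver(
    config=config,
    net=model,                 # a paddle.nn.Layer
    loss=loss_func,            # may be None; handled in save()
    opt=optimizer,             # list of optimizers; each state_dict is saved
    model_ema=None)
saver.save(metric_info={"metric": 0.0, "epoch": epoch_id}, prefix="latest")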