Commit 5a6fe171 authored by gaotingquan, committed by Wei Shengyu

refactor: rm train and eval from engine

Parent commit: 187f38eb
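This commit moves the training and evaluation loops out of `Engine` into dedicated callables built by `build_train_func` and `build_eval_func`, so `Engine.__init__` only wires components together. A minimal sketch of the resulting control flow, with the `Engine` plumbing simplified (the builder signatures are taken from the diff below; everything else here is illustrative):

```python
# Sketch only: the real Engine also handles seeding, logging, device and
# distributed setup; see the diff for the actual wiring.
class Engine:
    def __init__(self, config, mode):
        self.model = build_model(config, mode)
        # train/eval are now standalone callables (None outside their modes),
        # not methods implemented on Engine itself.
        self.eval = build_eval_func(config, mode=mode, model=self.model)
        self.train = build_train_func(
            config, mode=mode, model=self.model, eval_func=self.eval)

engine = Engine(config, mode="train")
engine.train()           # ClassTrainer.__call__: runs the whole training loop
top1 = engine.eval(0)    # ClassEval.__call__: one eval pass for epoch 0
```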
ppcls/engine/engine.py:
@@ -22,25 +22,17 @@ from paddle import nn
 import numpy as np
 import random
-from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
-from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
-from ppcls.loss import build_loss
-from ppcls.metric import build_metrics
-from ppcls.optimizer import build_optimizer
-from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
-from ppcls.utils.save_load import init_model, ModelSaver
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
-from .train import build_train_epoch_func
+from .train import build_train_func
 from .evaluation import build_eval_func
-from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead
@@ -50,186 +42,35 @@ class Engine(object):
         assert mode in ["train", "eval", "infer", "export"]
         self.mode = mode
         self.config = config
-        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
-                                                          0) - 1
-        self.epochs = self.config["Global"].get("epochs", 1)

         # set seed
         self._init_seed()

         # init logger
-        self.output_dir = self.config['Global']['output_dir']
-        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
-                                f"{mode}.log")
+        log_file = os.path.join(self.config['Global']['output_dir'],
+                                self.config["Arch"]["name"], f"{mode}.log")
         init_logger(log_file=log_file)

-        # for visualdl
-        self.vdl_writer = self._init_vdl()
-
-        # init train_func and eval_func
-        self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_func = build_eval_func(self.config)
-
         # set device
         self._init_device()

-        # gradient accumulation
-        self.update_freq = self.config["Global"].get("update_freq", 1)
-
-        # build dataloader
-        self.use_dali = self.config["Global"].get("use_dali", False)
-        self.dataloader_dict = build_dataloader(self.config, mode)
-
-        # build loss
-        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
-            self.config, self.mode)
-
-        # build metric
-        self.train_metric_func, self.eval_metric_func = build_metrics(self)
-
         # build model
         self.model = build_model(self.config, self.mode)
         # load_pretrain
         self._init_pretrained()

-        # build optimizer
-        self.optimizer, self.lr_sch = build_optimizer(self)
-
-        # AMP training and evaluating
-        self._init_amp()
+        # init train_func and eval_func
+        self.eval = build_eval_func(
+            self.config, mode=self.mode, model=self.model)
+        self.train = build_train_func(
+            self.config, mode=self.mode, model=self.model, eval_func=self.eval)

         # for distributed
         self._init_dist()

-        # build model saver
-        self.model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
-
         print_config(config)
-    def train(self):
-        assert self.mode == "train"
-        print_batch_step = self.config['Global']['print_batch_step']
-        save_interval = self.config["Global"]["save_interval"]
-        best_metric = {
-            "metric": -1.0,
-            "epoch": 0,
-        }
-        # key:
-        # val: metrics list word
-        self.output_info = dict()
-        self.time_info = {
-            "batch_cost": AverageMeter(
-                "batch_cost", '.5f', postfix=" s,"),
-            "reader_cost": AverageMeter(
-                "reader_cost", ".5f", postfix=" s,"),
-        }
-        # build EMA model
-        self.model_ema = self._build_ema_model()
-        # TODO: mv best_metric_ema to best_metric dict
-        best_metric_ema = 0
-        self._init_checkpoints(best_metric)
-
-        # global iter counter
-        self.global_step = 0
-
-        for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
-            # for one epoch train
-            self.train_epoch_func(self, epoch_id, print_batch_step)
-            metric_msg = ", ".join(
-                [self.output_info[key].avg_info for key in self.output_info])
-            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
-                epoch_id, self.epochs, metric_msg))
-            self.output_info.clear()
-
-            acc = 0.0
-            if self.config["Global"][
-                    "eval_during_train"] and epoch_id % self.config["Global"][
-                        "eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
-                acc = self.eval(epoch_id)
-
-                # step lr (by epoch) according to given metric, such as acc
-                for i in range(len(self.lr_sch)):
-                    if getattr(self.lr_sch[i], "by_epoch", False) and \
-                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
-                        self.lr_sch[i].step(acc)
-
-                if acc > best_metric["metric"]:
-                    best_metric["metric"] = acc
-                    best_metric["epoch"] = epoch_id
-                    self.model_saver.save(
-                        best_metric,
-                        prefix="best_model",
-                        save_student_model=True)
-                logger.info("[Eval][Epoch {}][best metric: {}]".format(
-                    epoch_id, best_metric["metric"]))
-                logger.scaler(
-                    name="eval_acc",
-                    value=acc,
-                    step=epoch_id,
-                    writer=self.vdl_writer)
-                self.model.train()
-
-                if self.model_ema:
-                    ori_model, self.model = self.model, self.model_ema.module
-                    acc_ema = self.eval(epoch_id)
-                    self.model = ori_model
-                    self.model_ema.module.eval()
-
-                    if acc_ema > best_metric_ema:
-                        best_metric_ema = acc_ema
-                        self.model_saver.save(
-                            {
-                                "metric": acc_ema,
-                                "epoch": epoch_id
-                            },
-                            prefix="best_model_ema")
-                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
-                        epoch_id, best_metric_ema))
-                    logger.scaler(
-                        name="eval_acc_ema",
-                        value=acc_ema,
-                        step=epoch_id,
-                        writer=self.vdl_writer)
-
-            # save model
-            if save_interval > 0 and epoch_id % save_interval == 0:
-                self.model_saver.save(
-                    {
-                        "metric": acc,
-                        "epoch": epoch_id
-                    },
-                    prefix=f"epoch_{epoch_id}")
-            # save the latest model
-            self.model_saver.save(
-                {
-                    "metric": acc,
-                    "epoch": epoch_id
-                }, prefix="latest")
-
-        if self.vdl_writer is not None:
-            self.vdl_writer.close()
-
-    @paddle.no_grad()
-    def eval(self, epoch_id=0):
-        assert self.mode in ["train", "eval"]
-        self.model.eval()
-        eval_result = self.eval_func(self, epoch_id)
-        self.model.train()
-        return eval_result
-
     @paddle.no_grad()
     def infer(self):
         assert self.mode == "infer" and self.eval_mode == "classification"
@@ -326,15 +167,6 @@ class Engine(object):
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
         )

-    def _init_vdl(self):
-        if self.config['Global'][
-                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
-            vdl_writer_path = os.path.join(self.output_dir, "vdl")
-            if not os.path.exists(vdl_writer_path):
-                os.makedirs(vdl_writer_path)
-            return LogWriter(logdir=vdl_writer_path)
-        return None
-
     def _init_seed(self):
         seed = self.config["Global"].get("seed", False)
         if dist.get_world_size() != 1:
...@@ -455,22 +287,6 @@ class Engine(object): ...@@ -455,22 +287,6 @@ class Engine(object):
self.train_loss_func = paddle.DataParallel( self.train_loss_func = paddle.DataParallel(
self.train_loss_func) self.train_loss_func)
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
return model_ema
else:
return None
def _init_checkpoints(self, best_metric):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
best_metric.update(metric_info)
class ExportModel(TheseusLayer): class ExportModel(TheseusLayer):
""" """
......
ppcls/engine/evaluation/__init__.py:
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .classification import classification_eval
+from .classification import ClassEval
 from .retrieval import retrieval_eval
 from .adaface import adaface_eval


-def build_eval_func(config):
+def build_eval_func(config, mode, model):
+    if mode not in ["eval", "train"]:
+        return None
     eval_mode = config["Global"].get("eval_mode", None)
     if eval_mode is None:
         config["Global"]["eval_mode"] = "classification"
-        return classification_eval
+        return ClassEval(config, mode, model)
     else:
-        return getattr(sys.modules[__name__], eval_mode + "_eval")
+        return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
+                                                                   mode, model)
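`build_eval_func` is now a small factory: rather than returning a bare function that expects the whole engine, it instantiates an evaluator bound to `config` and `model`, and returns `None` for modes that never evaluate. A hypothetical usage sketch (assumes `model` was already built; not a complete PaddleClas config):

```python
# Illustrative only; ClassEval builds its own dataloader, loss and metrics.
evaluator = build_eval_func(config, mode="eval", model=model)  # -> ClassEval
if evaluator is not None:         # None when mode is "infer" or "export"
    top1 = evaluator(epoch_id=0)  # one full eval pass, returns the metric
```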
ppcls/engine/evaluation/classification.py:
@@ -18,164 +18,185 @@ import time
 import platform

 import paddle

-from ppcls.utils.misc import AverageMeter
-from ppcls.utils import logger
-
-
-def classification_eval(engine, epoch_id=0):
-    if hasattr(engine.eval_metric_func, "reset"):
-        engine.eval_metric_func.reset()
-    output_info = dict()
-    time_info = {
-        "batch_cost": AverageMeter(
-            "batch_cost", '.5f', postfix=" s,"),
-        "reader_cost": AverageMeter(
-            "reader_cost", ".5f", postfix=" s,"),
-    }
-    print_batch_step = engine.config["Global"]["print_batch_step"]
-
-    tic = time.time()
-    total_samples = engine.dataloader_dict["Eval"].total_samples
-    accum_samples = 0
-    max_iter = engine.dataloader_dict["Eval"].max_iter
-    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
-        if iter_id >= max_iter:
-            break
-        if iter_id == 5:
-            for key in time_info:
-                time_info[key].reset()
-
-        time_info["reader_cost"].update(time.time() - tic)
-        batch_size = batch[0].shape[0]
-        batch[0] = paddle.to_tensor(batch[0])
-        if not engine.config["Global"].get("use_multilabel", False):
-            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
-
-        # image input
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(
-                    custom_black_list={
-                        "flatten_contiguous_range", "greater_than"
-                    },
-                    level=engine.amp_level):
-                out = engine.model(batch)
-        else:
-            out = engine.model(batch)
-
-        # just for DistributedBatchSampler issue: repeat sampling
-        current_samples = batch_size * paddle.distributed.get_world_size()
-        accum_samples += current_samples
-
-        if isinstance(out, dict) and "Student" in out:
-            out = out["Student"]
-        if isinstance(out, dict) and "logits" in out:
-            out = out["logits"]
-
-        # gather Tensor when distributed
-        if paddle.distributed.get_world_size() > 1:
-            label_list = []
-            device_id = paddle.distributed.ParallelEnv().device_id
-            label = batch[1].cuda(device_id) if engine.config["Global"][
-                "device"] == "gpu" else batch[1]
-            paddle.distributed.all_gather(label_list, label)
-            labels = paddle.concat(label_list, 0)
-
-            if isinstance(out, list):
-                preds = []
-                for x in out:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, x)
-                    pred_x = paddle.concat(pred_list, 0)
-                    preds.append(pred_x)
-            else:
-                pred_list = []
-                paddle.distributed.all_gather(pred_list, out)
-                preds = paddle.concat(pred_list, 0)
-
-            if accum_samples > total_samples and not engine.use_dali:
-                if isinstance(preds, list):
-                    preds = [
-                        pred[:total_samples + current_samples - accum_samples]
-                        for pred in preds
-                    ]
-                else:
-                    preds = preds[:total_samples + current_samples -
-                                  accum_samples]
-                labels = labels[:total_samples + current_samples -
-                                accum_samples]
-                current_samples = total_samples + current_samples - accum_samples
-        else:
-            labels = batch[1]
-            preds = out
-
-        # calc loss
-        if engine.eval_loss_func is not None:
-            if engine.amp and engine.amp_eval:
-                with paddle.amp.auto_cast(
-                        custom_black_list={
-                            "flatten_contiguous_range", "greater_than"
-                        },
-                        level=engine.amp_level):
-                    loss_dict = engine.eval_loss_func(preds, labels)
-            else:
-                loss_dict = engine.eval_loss_func(preds, labels)
-
-            for key in loss_dict:
-                if key not in output_info:
-                    output_info[key] = AverageMeter(key, '7.5f')
-                output_info[key].update(float(loss_dict[key]), current_samples)
-
-        # calc metric
-        if engine.eval_metric_func is not None:
-            engine.eval_metric_func(preds, labels)
-        time_info["batch_cost"].update(time.time() - tic)
-
-        if iter_id % print_batch_step == 0:
-            time_msg = "s, ".join([
-                "{}: {:.5f}".format(key, time_info[key].avg)
-                for key in time_info
-            ])
-
-            ips_msg = "ips: {:.5f} images/sec".format(
-                batch_size / time_info["batch_cost"].avg)
-
-            if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
-                metric_msg = ""
-            else:
-                metric_msg = ", ".join([
-                    "{}: {:.5f}".format(key, output_info[key].val)
-                    for key in output_info
-                ])
-                metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
-            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
-
-        tic = time.time()
-    if engine.use_dali:
-        engine.dataloader_dict["Eval"].reset()
-
-    if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
-        metric_msg = ", ".join([
-            "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
-            format(*engine.eval_metric_func.attr_res())
-        ])
-        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
-
-        # do not try to save best eval.model
-        if engine.eval_metric_func is None:
-            return -1
-        # return 1st metric in the dict
-        return engine.eval_metric_func.attr_res()[0]
-    else:
-        metric_msg = ", ".join([
-            "{}: {:.5f}".format(key, output_info[key].avg)
-            for key in output_info
-        ])
-        metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
-        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
-
-        # do not try to save best eval.model
-        if engine.eval_metric_func is None:
-            return -1
-        # return 1st metric in the dict
-        return engine.eval_metric_func.avg
+from ...utils.misc import AverageMeter
+from ...utils import logger
+from ...data import build_dataloader
+from ...loss import build_loss
+from ...metric import build_metrics
+
+
+class ClassEval(object):
+    def __init__(self, config, mode, model):
+        self.config = config
+        self.model = model
+        self.use_dali = self.config["Global"].get("use_dali", False)
+        self.eval_metric_func = build_metrics(config, "eval")
+        self.eval_dataloader = build_dataloader(config, "eval")
+        self.eval_loss_func = build_loss(config, "eval")
+        self.output_info = dict()
+
+    @paddle.no_grad()
+    def __call__(self, epoch_id=0):
+        self.model.eval()
+
+        if hasattr(self.eval_metric_func, "reset"):
+            self.eval_metric_func.reset()
+
+        time_info = {
+            "batch_cost": AverageMeter(
+                "batch_cost", '.5f', postfix=" s,"),
+            "reader_cost": AverageMeter(
+                "reader_cost", ".5f", postfix=" s,"),
+        }
+        print_batch_step = self.config["Global"]["print_batch_step"]
+
+        tic = time.time()
+        total_samples = self.eval_dataloader["Eval"].total_samples
+        accum_samples = 0
+        max_iter = self.eval_dataloader["Eval"].max_iter
+        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
+            if iter_id >= max_iter:
+                break
+            if iter_id == 5:
+                for key in time_info:
+                    time_info[key].reset()
+
+            time_info["reader_cost"].update(time.time() - tic)
+            batch_size = batch[0].shape[0]
+            batch[0] = paddle.to_tensor(batch[0])
+            if not self.config["Global"].get("use_multilabel", False):
+                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+
+            # image input
+            # if engine.amp and engine.amp_eval:
+            #     with paddle.amp.auto_cast(
+            #             custom_black_list={
+            #                 "flatten_contiguous_range", "greater_than"
+            #             },
+            #             level=engine.amp_level):
+            #         out = engine.model(batch)
+            # else:
+            #     out = self.model(batch)
+            out = self.model(batch)
+
+            # just for DistributedBatchSampler issue: repeat sampling
+            current_samples = batch_size * paddle.distributed.get_world_size()
+            accum_samples += current_samples
+
+            if isinstance(out, dict) and "Student" in out:
+                out = out["Student"]
+            if isinstance(out, dict) and "logits" in out:
+                out = out["logits"]
+
+            # gather Tensor when distributed
+            if paddle.distributed.get_world_size() > 1:
+                label_list = []
+                device_id = paddle.distributed.ParallelEnv().device_id
+                label = batch[1].cuda(device_id) if self.config["Global"][
+                    "device"] == "gpu" else batch[1]
+                paddle.distributed.all_gather(label_list, label)
+                labels = paddle.concat(label_list, 0)
+
+                if isinstance(out, list):
+                    preds = []
+                    for x in out:
+                        pred_list = []
+                        paddle.distributed.all_gather(pred_list, x)
+                        pred_x = paddle.concat(pred_list, 0)
+                        preds.append(pred_x)
+                else:
+                    pred_list = []
+                    paddle.distributed.all_gather(pred_list, out)
+                    preds = paddle.concat(pred_list, 0)
+
+                if accum_samples > total_samples and not self.use_dali:
+                    if isinstance(preds, list):
+                        preds = [
+                            pred[:total_samples + current_samples -
+                                 accum_samples] for pred in preds
+                        ]
+                    else:
+                        preds = preds[:total_samples + current_samples -
+                                      accum_samples]
+                    labels = labels[:total_samples + current_samples -
+                                    accum_samples]
+                    current_samples = total_samples + current_samples - accum_samples
+            else:
+                labels = batch[1]
+                preds = out
+
+            # calc loss
+            if self.eval_loss_func is not None:
+                # if self.amp and self.amp_eval:
+                #     with paddle.amp.auto_cast(
+                #             custom_black_list={
+                #                 "flatten_contiguous_range", "greater_than"
+                #             },
+                #             level=engine.amp_level):
+                #         loss_dict = engine.eval_loss_func(preds, labels)
+                # else:
+                loss_dict = self.eval_loss_func(preds, labels)
+
+                for key in loss_dict:
+                    if key not in self.output_info:
+                        self.output_info[key] = AverageMeter(key, '7.5f')
+                    self.output_info[key].update(
+                        float(loss_dict[key]), current_samples)
+
+            # calc metric
+            if self.eval_metric_func is not None:
+                self.eval_metric_func(preds, labels)
+            time_info["batch_cost"].update(time.time() - tic)
+
+            if iter_id % print_batch_step == 0:
+                time_msg = "s, ".join([
+                    "{}: {:.5f}".format(key, time_info[key].avg)
+                    for key in time_info
+                ])
+
+                ips_msg = "ips: {:.5f} images/sec".format(
+                    batch_size / time_info["batch_cost"].avg)
+
+                if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
+                    metric_msg = ""
+                else:
+                    metric_msg = ", ".join([
+                        "{}: {:.5f}".format(key, self.output_info[key].val)
+                        for key in self.output_info
+                    ])
+                    metric_msg += ", {}".format(self.eval_metric_func.avg_info)
+                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
+                    epoch_id, iter_id, max_iter, metric_msg, time_msg,
+                    ips_msg))
+
+            tic = time.time()
+        if self.use_dali:
+            self.eval_dataloader["Eval"].reset()
+
+        if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
+            metric_msg = ", ".join([
+                "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
+                format(*self.eval_metric_func.attr_res())
+            ])
+            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
+
+            # do not try to save best eval.model
+            if self.eval_metric_func is None:
+                return -1
+            # return 1st metric in the dict
+            return self.eval_metric_func.attr_res()[0]
+        else:
+            metric_msg = ", ".join([
+                "{}: {:.5f}".format(key, self.output_info[key].avg)
+                for key in self.output_info
+            ])
+            metric_msg += ", {}".format(self.eval_metric_func.avg_info)
+            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
+
+            # do not try to save best eval.model
+            if self.eval_metric_func is None:
+                return -1
+            # return 1st metric in the dict
+            return self.eval_metric_func.avg
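One subtlety shared by the old function and the new class: under multi-GPU evaluation, `DistributedBatchSampler` pads the dataset so every rank runs the same number of batches, so the all-gathered predictions can overshoot the true dataset size. The `total_samples + current_samples - accum_samples` expression trims the final gathered batch back down. A standalone arithmetic check (numbers invented for illustration):

```python
# 8 ranks x batch_size 16 => 128 gathered samples per step. With a
# 1000-sample dataset, after 8 steps accum_samples = 1024 > 1000, so only
# the first 1000 + 128 - 1024 = 104 rows of the last batch are kept.
total_samples, current_samples, accum_samples = 1000, 128, 1024
keep = total_samples + current_samples - accum_samples
assert keep == 104  # preds = preds[:keep]; labels = labels[:keep]
```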
ppcls/engine/train/__init__.py:
@@ -13,16 +13,19 @@
 # limitations under the License.
 from .train_metabin import train_epoch_metabin
-from .regular_train_epoch import regular_train_epoch
+from .classification import ClassTrainer
 from .train_fixmatch import train_epoch_fixmatch
 from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
 from .train_progressive import train_epoch_progressive


-def build_train_epoch_func(config):
-    train_mode = config["Global"].get("train_mode", None)
+def build_train_func(config, mode, model, eval_func):
+    if mode != "train":
+        return None
+    train_mode = config["Global"].get("task", None)
     if train_mode is None:
-        config["Global"]["train_mode"] = "regular_train"
-        return regular_train_epoch
+        config["Global"]["task"] = "classification"
+        return ClassTrainer(config, mode, model, eval_func)
     else:
-        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
+        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
+            config, mode, model, eval_func)
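Note the config key rename: the trainer is selected by `Global.task` instead of the old `Global.train_mode`, and the default path now returns a constructed `ClassTrainer` rather than a bare per-epoch function. A hypothetical sketch of the default path (the non-default `train_epoch_*` entries keep their old names and are resolved via `getattr`):

```python
# Illustrative only. With no Global.task set, the default classification
# trainer is built; the entire run is launched by calling it.
trainer = build_train_func(config, "train", model, eval_func)  # ClassTrainer
if trainer is not None:  # None for eval/infer/export modes
    trainer()            # trains for Global.epochs, evaluating periodically
```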
ppcls/engine/train/classification.py (new file):

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import os
import time

import paddle
import paddle.distributed as dist  # needed by _init_vdl (rank check)
from visualdl import LogWriter  # needed by _init_vdl

from .utils import update_loss, update_metric, log_info
from ...utils import logger, profiler, type_name
from ...utils.misc import AverageMeter
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
from ...optimizer import build_optimizer
from ...utils.ema import ExponentialMovingAverage
from ...utils.save_load import init_model, ModelSaver


class ClassTrainer(object):
    def __init__(self, config, mode, model, eval_func):
        self.config = config
        self.mode = mode  # referenced by _build_ema_model
        self.model = model
        self.eval = eval_func

        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                          0) - 1
        self.epochs = self.config["Global"].get("epochs", 1)
        self.print_batch_step = self.config['Global']['print_batch_step']
        self.save_interval = self.config["Global"]["save_interval"]
        self.output_dir = self.config['Global']['output_dir']

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # AMP training and evaluating
        # self._init_amp()

        # build dataloader
        self.use_dali = self.config["Global"].get("use_dali", False)
        self.dataloader_dict = build_dataloader(self.config, mode)

        # build loss
        self.train_loss_func, self.unlabel_train_loss_func = build_loss(
            self.config, mode)

        # build metric
        self.train_metric_func = build_metrics(config, "train")

        # build optimizer
        self.optimizer, self.lr_sch = build_optimizer(
            self.config, self.dataloader_dict["Train"].max_iter,
            [self.model, self.train_loss_func], self.update_freq)

        # build model saver
        self.model_saver = ModelSaver(
            self,
            net_name="model",
            loss_name="train_loss_func",
            opt_name="optimizer",
            model_ema_name="model_ema")

        # build best metric
        self.best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }

        # key:
        # val: metrics list word
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }

        # build EMA model
        self.model_ema = self._build_ema_model()
        self._init_checkpoints()

        # for visualdl
        self.vdl_writer = self._init_vdl()

    def __call__(self):
        # global iter counter
        self.global_step = 0
        for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
            # train one epoch
            self.train_epoch(epoch_id)
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.epochs, metric_msg))
            self.output_info.clear()

            acc = 0.0
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
                acc = self.eval(epoch_id)

                # step lr (by epoch) according to given metric, such as acc
                for i in range(len(self.lr_sch)):
                    if getattr(self.lr_sch[i], "by_epoch", False) and \
                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
                        self.lr_sch[i].step(acc)

                if acc > self.best_metric["metric"]:
                    self.best_metric["metric"] = acc
                    self.best_metric["epoch"] = epoch_id
                    self.model_saver.save(
                        self.best_metric,
                        prefix="best_model",
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, self.best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)
                self.model.train()

                if self.model_ema:
                    ori_model, self.model = self.model, self.model_ema.module
                    acc_ema = self.eval(epoch_id)
                    self.model = ori_model
                    self.model_ema.module.eval()

                    if acc_ema > self.best_metric["metric_ema"]:
                        self.best_metric["metric_ema"] = acc_ema
                        self.model_saver.save(
                            {
                                "metric": acc_ema,
                                "epoch": epoch_id
                            },
                            prefix="best_model_ema")
                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
                        epoch_id, self.best_metric["metric_ema"]))
                    logger.scaler(
                        name="eval_acc_ema",
                        value=acc_ema,
                        step=epoch_id,
                        writer=self.vdl_writer)

            # save model
            if self.save_interval > 0 and epoch_id % self.save_interval == 0:
                self.model_saver.save(
                    {
                        "metric": acc,
                        "epoch": epoch_id
                    },
                    prefix=f"epoch_{epoch_id}")
            # save the latest model
            self.model_saver.save(
                {
                    "metric": acc,
                    "epoch": epoch_id
                }, prefix="latest")

    def train_epoch(self, epoch_id):
        tic = time.time()
        for iter_id in range(self.dataloader_dict["Train"].max_iter):
            batch = self.dataloader_dict["Train"].get_batch()

            profiler.add_profiler_step(self.config["profiler_options"])
            if iter_id == 5:
                for key in self.time_info:
                    self.time_info[key].reset()
            self.time_info["reader_cost"].update(time.time() - tic)

            batch_size = batch[0].shape[0]
            if not self.config["Global"].get("use_multilabel", False):
                batch[1] = batch[1].reshape([batch_size, -1])
            self.global_step += 1

            # forward & backward & step opt
            # if engine.amp:
            #     with paddle.amp.auto_cast(
            #             custom_black_list={
            #                 "flatten_contiguous_range", "greater_than"
            #             },
            #             level=engine.amp_level):
            #         out = engine.model(batch)
            #         loss_dict = engine.train_loss_func(out, batch[1])
            #     loss = loss_dict["loss"] / engine.update_freq
            #     scaled = engine.scaler.scale(loss)
            #     scaled.backward()
            #     if (iter_id + 1) % engine.update_freq == 0:
            #         for i in range(len(engine.optimizer)):
            #             engine.scaler.minimize(engine.optimizer[i], scaled)
            # else:
            #     out = engine.model(batch)
            #     loss_dict = engine.train_loss_func(out, batch[1])
            #     loss = loss_dict["loss"] / engine.update_freq
            #     loss.backward()
            #     if (iter_id + 1) % engine.update_freq == 0:
            #         for i in range(len(engine.optimizer)):
            #             engine.optimizer[i].step()
            out = self.model(batch)
            loss_dict = self.train_loss_func(out, batch[1])
            loss = loss_dict["loss"] / self.update_freq
            loss.backward()
            if (iter_id + 1) % self.update_freq == 0:
                for i in range(len(self.optimizer)):
                    self.optimizer[i].step()

            if (iter_id + 1) % self.update_freq == 0:
                # clear grad
                for i in range(len(self.optimizer)):
                    self.optimizer[i].clear_grad()
                # step lr (by step)
                for i in range(len(self.lr_sch)):
                    if not getattr(self.lr_sch[i], "by_epoch", False):
                        self.lr_sch[i].step()
                # update ema
                if self.model_ema:
                    self.model_ema.update(self.model)

            # below code just for logging
            # update metric_for_logger
            update_metric(self, out, batch, batch_size)
            # update_loss_for_logger
            update_loss(self, loss_dict, batch_size)
            self.time_info["batch_cost"].update(time.time() - tic)
            if iter_id % self.print_batch_step == 0:
                log_info(self, batch_size, epoch_id, iter_id)
            tic = time.time()

        # step lr (by epoch)
        for i in range(len(self.lr_sch)):
            if getattr(self.lr_sch[i], "by_epoch", False) and \
                    type_name(self.lr_sch[i]) != "ReduceOnPlateau":
                self.lr_sch[i].step()

    def __del__(self):
        if self.vdl_writer is not None:
            self.vdl_writer.close()

    def _init_vdl(self):
        if self.config['Global']['use_visualdl'] and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            return LogWriter(logdir=vdl_writer_path)
        return None

    def _build_ema_model(self):
        if "EMA" in self.config and self.mode == "train":
            model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))
            self.best_metric["metric_ema"] = 0
            return model_ema
        else:
            return None

    def _init_checkpoints(self):
        if self.config["Global"].get("checkpoints", None) is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func,
                                     self.model_ema)
            if metric_info is not None:
                self.best_metric.update(metric_info)
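The `update_freq` handling above is plain gradient accumulation: the loss is scaled by `1/update_freq`, gradients accumulate across iterations, and the optimizer steps and clears only on every `update_freq`-th batch, giving an effective batch size of `batch_size * update_freq`. A self-contained sketch of the same pattern in generic Paddle (hypothetical `loader` and `loss_fn`, single optimizer):

```python
def accumulate_train(model, loss_fn, optimizer, loader, update_freq=4):
    # Effective batch size = loader batch size * update_freq.
    for iter_id, (x, y) in enumerate(loader):
        loss = loss_fn(model(x), y) / update_freq  # scale so grads average
        loss.backward()                            # grads accumulate
        if (iter_id + 1) % update_freq == 0:
            optimizer.step()
            optimizer.clear_grad()                 # start the next window
```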
ppcls/engine/train/regular_train_epoch.py (deleted; its loop moves into ClassTrainer.train_epoch above):

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time

import paddle

from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler


def regular_train_epoch(engine, epoch_id, print_batch_step):
    tic = time.time()
    for iter_id in range(engine.dataloader_dict["Train"].max_iter):
        batch = engine.dataloader_dict["Train"].get_batch()

        profiler.add_profiler_step(engine.config["profiler_options"])
        if iter_id == 5:
            for key in engine.time_info:
                engine.time_info[key].reset()
        engine.time_info["reader_cost"].update(time.time() - tic)

        batch_size = batch[0].shape[0]
        if not engine.config["Global"].get("use_multilabel", False):
            batch[1] = batch[1].reshape([batch_size, -1])
        engine.global_step += 1

        # forward & backward & step opt
        if engine.amp:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch)
                loss_dict = engine.train_loss_func(out, batch[1])
            loss = loss_dict["loss"] / engine.update_freq
            scaled = engine.scaler.scale(loss)
            scaled.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            out = engine.model(batch)
            loss_dict = engine.train_loss_func(out, batch[1])
            loss = loss_dict["loss"] / engine.update_freq
            loss.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.optimizer[i].step()

        if (iter_id + 1) % engine.update_freq == 0:
            # clear grad
            for i in range(len(engine.optimizer)):
                engine.optimizer[i].clear_grad()
            # step lr (by step)
            for i in range(len(engine.lr_sch)):
                if not getattr(engine.lr_sch[i], "by_epoch", False):
                    engine.lr_sch[i].step()
            # update ema
            if engine.model_ema:
                engine.model_ema.update(engine.model)

        # below code just for logging
        # update metric_for_logger
        update_metric(engine, out, batch, batch_size)
        # update_loss_for_logger
        update_loss(engine, loss_dict, batch_size)
        engine.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
        tic = time.time()

    # step lr (by epoch)
    for i in range(len(engine.lr_sch)):
        if getattr(engine.lr_sch[i], "by_epoch", False) and \
                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
            engine.lr_sch[i].step()
ppcls/utils/save_load.py:
@@ -151,20 +151,20 @@ def _extract_student_weights(all_params, student_prefix="Student."):

 class ModelSaver(object):
     def __init__(self,
-                 engine,
+                 trainer,
                  net_name="model",
                  loss_name="train_loss_func",
                  opt_name="optimizer",
                  model_ema_name="model_ema"):
         # net, loss, opt, model_ema, output_dir,
-        self.engine = engine
+        self.trainer = trainer
         self.net_name = net_name
         self.loss_name = loss_name
         self.opt_name = opt_name
         self.model_ema_name = model_ema_name

-        arch_name = engine.config["Arch"]["name"]
-        self.output_dir = os.path.join(engine.output_dir, arch_name)
+        arch_name = trainer.config["Arch"]["name"]
+        self.output_dir = os.path.join(trainer.output_dir, arch_name)
         _mkdir_if_not_exist(self.output_dir)

     def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +174,8 @@ class ModelSaver(object):
         save_dir = os.path.join(self.output_dir, prefix)

-        params_state_dict = getattr(self.engine, self.net_name).state_dict()
-        loss = getattr(self.engine, self.loss_name)
+        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
+        loss = getattr(self.trainer, self.loss_name)
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +190,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")

         paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.engine, self.model_ema_name)
+        model_ema = getattr(self.trainer, self.model_ema_name)
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")

-        optimizer = getattr(self.engine, self.opt_name)
+        optimizer = getattr(self.trainer, self.opt_name)
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")
......
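The `engine` → `trainer` rename in `ModelSaver` is mechanical because the saver resolves each component by attribute name on whatever object it is handed. A reduced sketch of that lookup pattern (the attribute names are the defaults shown above; `trainer` is assumed to be a `ClassTrainer`, which exposes all of them):

```python
saver = ModelSaver(trainer)              # defaults shown in the diff
# inside save(), each piece is fetched off the trainer by name:
net = getattr(trainer, "model")          # weights -> <prefix>.pdparams
opts = getattr(trainer, "optimizer")     # states  -> <prefix>.pdopt
ema = getattr(trainer, "model_ema")      # EMA     -> <prefix>.ema.pdparams
saver.save({"metric": 0.0, "epoch": 0}, prefix="latest")
```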