Commit 915dde17 authored by Tingquan Gao

Revert "refactor: rm train and eval from engine"

This reverts commit 5a6fe171.
Parent aa52682c
@@ -22,17 +22,25 @@ from paddle import nn
import numpy as np
import random
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model, ModelSaver
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
-from .train import build_train_func
from .train import build_train_epoch_func
from .evaluation import build_eval_func
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
@@ -42,35 +50,186 @@ class Engine(object):
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                           0) - 1
        self.epochs = self.config["Global"].get("epochs", 1)

        # set seed
        self._init_seed()

        # init logger
-        log_file = os.path.join(self.config['Global']['output_dir'],
-                                self.config["Arch"]["name"], f"{mode}.log")
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)

        # for visualdl
        self.vdl_writer = self._init_vdl()

        # init train_func and eval_func
        self.train_epoch_func = build_train_epoch_func(self.config)
        self.eval_func = build_eval_func(self.config)

        # set device
        self._init_device()

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # build dataloader
        self.use_dali = self.config["Global"].get("use_dali", False)
        self.dataloader_dict = build_dataloader(self.config, mode)

        # build loss
        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
            self.config, self.mode)

        # build metric
        self.train_metric_func, self.eval_metric_func = build_metrics(self)

        # build model
        self.model = build_model(self.config, self.mode)
        # load_pretrain
        self._init_pretrained()

-        # init train_func and eval_func
-        self.eval = build_eval_func(
-            self.config, mode=self.mode, model=self.model)
-        self.train = build_train_func(
-            self.config, mode=self.mode, model=self.model, eval_func=self.eval)
        # build optimizer
        self.optimizer, self.lr_sch = build_optimizer(self)

        # AMP training and evaluating
        self._init_amp()

        # for distributed
        self._init_dist()

        # build model saver
        self.model_saver = ModelSaver(
            self,
            net_name="model",
            loss_name="train_loss_func",
            opt_name="optimizer",
            model_ema_name="model_ema")

        print_config(config)
def train(self):
assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
best_metric = {
"metric": -1.0,
"epoch": 0,
}
# key:
# val: metrics list word
self.output_info = dict()
self.time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# build EMA model
self.model_ema = self._build_ema_model()
# TODO: mv best_metric_ema to best_metric dict
best_metric_ema = 0
self._init_checkpoints(best_metric)
# global iter counter
self.global_step = 0
for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.epochs, metric_msg))
self.output_info.clear()
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc)
if acc > best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
self.model_saver.save(
best_metric,
prefix="best_model",
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"]))
logger.scaler(
name="eval_acc",
value=acc,
step=epoch_id,
writer=self.vdl_writer)
self.model.train()
if self.model_ema:
ori_model, self.model = self.model, self.model_ema.module
acc_ema = self.eval(epoch_id)
self.model = ori_model
self.model_ema.module.eval()
if acc_ema > best_metric_ema:
best_metric_ema = acc_ema
self.model_saver.save(
{
"metric": acc_ema,
"epoch": epoch_id
},
prefix="best_model_ema")
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, best_metric_ema))
logger.scaler(
name="eval_acc_ema",
value=acc_ema,
step=epoch_id,
writer=self.vdl_writer)
# save model
if save_interval > 0 and epoch_id % save_interval == 0:
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
},
prefix=f"epoch_{epoch_id}")
# save the latest model
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
}, prefix="latest")
if self.vdl_writer is not None:
self.vdl_writer.close()
@paddle.no_grad()
def eval(self, epoch_id=0):
assert self.mode in ["train", "eval"]
self.model.eval()
eval_result = self.eval_func(self, epoch_id)
self.model.train()
return eval_result
    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"
@@ -167,6 +326,15 @@
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )
    def _init_vdl(self):
        if self.config['Global'][
                'use_visualdl'] and self.mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            return LogWriter(logdir=vdl_writer_path)
        return None
    def _init_seed(self):
        seed = self.config["Global"].get("seed", False)
        if dist.get_world_size() != 1:
@@ -287,6 +455,22 @@ class Engine(object):
            self.train_loss_func = paddle.DataParallel(
                self.train_loss_func)
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
return model_ema
else:
return None
def _init_checkpoints(self, best_metric):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
best_metric.update(metric_info)
class ExportModel(TheseusLayer):
    """
...
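With train and eval restored as methods of Engine, a training entry point can again drive everything through a single object. A minimal sketch in the spirit of PaddleClas' tools/train.py is shown below; the config helpers (parse_args, get_config) are assumptions about ppcls.utils.config and are not part of this diff.

# Sketch only: entry point for the restored Engine API.
from ppcls.utils import config
from ppcls.engine.engine import Engine

if __name__ == "__main__":
    args = config.parse_args()
    cfg = config.get_config(args.config, overrides=args.override, show=False)
    engine = Engine(cfg, mode="train")
    engine.train()  # runs train_epoch_func per epoch and eval_func on the eval_interval schedule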
@@ -12,18 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .classification import ClassEval
from .classification import classification_eval
from .retrieval import retrieval_eval
from .adaface import adaface_eval


-def build_eval_func(config, mode, model):
-    if mode not in ["eval", "train"]:
-        return None
-    eval_mode = config["Global"].get("eval_mode", None)
-    if eval_mode is None:
-        config["Global"]["eval_mode"] = "classification"
-        return ClassEval(config, mode, model)
-    else:
-        return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
-                                                                   mode, model)
def build_eval_func(config):
    eval_mode = config["Global"].get("eval_mode", None)
    if eval_mode is None:
        config["Global"]["eval_mode"] = "classification"
        return classification_eval
    else:
        return getattr(sys.modules[__name__], eval_mode + "_eval")
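Note the restored contract: build_eval_func no longer builds an evaluator object bound to a model, it returns a plain function that the Engine stores as eval_func and calls as eval_func(engine, epoch_id). A minimal sketch of that calling convention follows; the run_eval wrapper is hypothetical and only illustrates how Engine.eval() uses the returned function.

# Sketch only: calling convention of the restored build_eval_func.
from ppcls.engine.evaluation import build_eval_func


def run_eval(engine, epoch_id=0):
    eval_func = build_eval_func(engine.config)  # e.g. classification_eval
    engine.model.eval()
    result = eval_func(engine, epoch_id)        # first metric of the metric dict
    engine.model.train()
    return result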
@@ -18,185 +18,164 @@ import time
import platform
import paddle
-from ...utils.misc import AverageMeter
-from ...utils import logger
-from ...data import build_dataloader
-from ...loss import build_loss
-from ...metric import build_metrics
-
-
-class ClassEval(object):
-    def __init__(self, config, mode, model):
-        self.config = config
-        self.model = model
-        self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(config, "eval")
-        self.eval_dataloader = build_dataloader(config, "eval")
-        self.eval_loss_func = build_loss(config, "eval")
-        self.output_info = dict()
-
-    @paddle.no_grad()
-    def __call__(self, epoch_id=0):
-        self.model.eval()
-
-        if hasattr(self.eval_metric_func, "reset"):
-            self.eval_metric_func.reset()
-
-        time_info = {
-            "batch_cost": AverageMeter(
-                "batch_cost", '.5f', postfix=" s,"),
-            "reader_cost": AverageMeter(
-                "reader_cost", ".5f", postfix=" s,"),
-        }
-        print_batch_step = self.config["Global"]["print_batch_step"]
-
-        tic = time.time()
-        total_samples = self.eval_dataloader["Eval"].total_samples
-        accum_samples = 0
-        max_iter = self.eval_dataloader["Eval"].max_iter
-        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
-            if iter_id >= max_iter:
-                break
-            if iter_id == 5:
-                for key in time_info:
-                    time_info[key].reset()
-
-            time_info["reader_cost"].update(time.time() - tic)
-            batch_size = batch[0].shape[0]
-            batch[0] = paddle.to_tensor(batch[0])
-            if not self.config["Global"].get("use_multilabel", False):
-                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
-
-            # image input
-            # if engine.amp and engine.amp_eval:
-            #     with paddle.amp.auto_cast(
-            #             custom_black_list={
-            #                 "flatten_contiguous_range", "greater_than"
-            #             },
-            #             level=engine.amp_level):
-            #         out = engine.model(batch)
-            # else:
-            #     out = self.model(batch)
-            out = self.model(batch)
-
-            # just for DistributedBatchSampler issue: repeat sampling
-            current_samples = batch_size * paddle.distributed.get_world_size()
-            accum_samples += current_samples
-
-            if isinstance(out, dict) and "Student" in out:
-                out = out["Student"]
-            if isinstance(out, dict) and "logits" in out:
-                out = out["logits"]
-
-            # gather Tensor when distributed
-            if paddle.distributed.get_world_size() > 1:
-                label_list = []
-                device_id = paddle.distributed.ParallelEnv().device_id
-                label = batch[1].cuda(device_id) if self.config["Global"][
-                    "device"] == "gpu" else batch[1]
-                paddle.distributed.all_gather(label_list, label)
-                labels = paddle.concat(label_list, 0)
-
-                if isinstance(out, list):
-                    preds = []
-                    for x in out:
-                        pred_list = []
-                        paddle.distributed.all_gather(pred_list, x)
-                        pred_x = paddle.concat(pred_list, 0)
-                        preds.append(pred_x)
-                else:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, out)
-                    preds = paddle.concat(pred_list, 0)
-
-                if accum_samples > total_samples and not self.use_dali:
-                    if isinstance(preds, list):
-                        preds = [
-                            pred[:total_samples + current_samples -
-                                 accum_samples] for pred in preds
-                        ]
-                    else:
-                        preds = preds[:total_samples + current_samples -
-                                      accum_samples]
-                    labels = labels[:total_samples + current_samples -
-                                    accum_samples]
-                    current_samples = total_samples + current_samples - accum_samples
-            else:
-                labels = batch[1]
-                preds = out
-
-            # calc loss
-            if self.eval_loss_func is not None:
-                # if self.amp and self.amp_eval:
-                #     with paddle.amp.auto_cast(
-                #             custom_black_list={
-                #                 "flatten_contiguous_range", "greater_than"
-                #             },
-                #             level=engine.amp_level):
-                #         loss_dict = engine.eval_loss_func(preds, labels)
-                # else:
-                loss_dict = self.eval_loss_func(preds, labels)
-
-                for key in loss_dict:
-                    if key not in self.output_info:
-                        self.output_info[key] = AverageMeter(key, '7.5f')
-                    self.output_info[key].update(
-                        float(loss_dict[key]), current_samples)
-
-            # calc metric
-            if self.eval_metric_func is not None:
-                self.eval_metric_func(preds, labels)
-            time_info["batch_cost"].update(time.time() - tic)
-
-            if iter_id % print_batch_step == 0:
-                time_msg = "s, ".join([
-                    "{}: {:.5f}".format(key, time_info[key].avg)
-                    for key in time_info
-                ])
-                ips_msg = "ips: {:.5f} images/sec".format(
-                    batch_size / time_info["batch_cost"].avg)
-                if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
-                    metric_msg = ""
-                else:
-                    metric_msg = ", ".join([
-                        "{}: {:.5f}".format(key, self.output_info[key].val)
-                        for key in self.output_info
-                    ])
-                    metric_msg += ", {}".format(self.eval_metric_func.avg_info)
-                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                    epoch_id, iter_id, max_iter, metric_msg, time_msg,
-                    ips_msg))
-            tic = time.time()
-
-        if self.use_dali:
-            self.eval_dataloader["Eval"].reset()
-
-        if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
-            metric_msg = ", ".join([
-                "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
-                format(*self.eval_metric_func.attr_res())
-            ])
-            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
-            # do not try to save best eval.model
-            if self.eval_metric_func is None:
-                return -1
-            # return 1st metric in the dict
-            return self.eval_metric_func.attr_res()[0]
-        else:
-            metric_msg = ", ".join([
-                "{}: {:.5f}".format(key, self.output_info[key].avg)
-                for key in self.output_info
-            ])
-            metric_msg += ", {}".format(self.eval_metric_func.avg_info)
-            logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
-            # do not try to save best eval.model
-            if self.eval_metric_func is None:
-                return -1
-            # return 1st metric in the dict
-            return self.eval_metric_func.avg
-        self.model.train()
-        return eval_result
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


def classification_eval(engine, epoch_id=0):
    if hasattr(engine.eval_metric_func, "reset"):
        engine.eval_metric_func.reset()
    output_info = dict()
    time_info = {
        "batch_cost": AverageMeter(
            "batch_cost", '.5f', postfix=" s,"),
        "reader_cost": AverageMeter(
            "reader_cost", ".5f", postfix=" s,"),
    }
    print_batch_step = engine.config["Global"]["print_batch_step"]

    tic = time.time()
    total_samples = engine.dataloader_dict["Eval"].total_samples
    accum_samples = 0
    max_iter = engine.dataloader_dict["Eval"].max_iter
    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
        if iter_id >= max_iter:
            break
        if iter_id == 5:
            for key in time_info:
                time_info[key].reset()

        time_info["reader_cost"].update(time.time() - tic)
        batch_size = batch[0].shape[0]
        batch[0] = paddle.to_tensor(batch[0])
        if not engine.config["Global"].get("use_multilabel", False):
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")

        # image input
        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch)
        else:
            out = engine.model(batch)

        # just for DistributedBatchSampler issue: repeat sampling
        current_samples = batch_size * paddle.distributed.get_world_size()
        accum_samples += current_samples

        if isinstance(out, dict) and "Student" in out:
            out = out["Student"]
        if isinstance(out, dict) and "logits" in out:
            out = out["logits"]

        # gather Tensor when distributed
        if paddle.distributed.get_world_size() > 1:
            label_list = []
            device_id = paddle.distributed.ParallelEnv().device_id
            label = batch[1].cuda(device_id) if engine.config["Global"][
                "device"] == "gpu" else batch[1]
            paddle.distributed.all_gather(label_list, label)
            labels = paddle.concat(label_list, 0)

            if isinstance(out, list):
                preds = []
                for x in out:
                    pred_list = []
                    paddle.distributed.all_gather(pred_list, x)
                    pred_x = paddle.concat(pred_list, 0)
                    preds.append(pred_x)
            else:
                pred_list = []
                paddle.distributed.all_gather(pred_list, out)
                preds = paddle.concat(pred_list, 0)

            if accum_samples > total_samples and not engine.use_dali:
                if isinstance(preds, list):
                    preds = [
                        pred[:total_samples + current_samples - accum_samples]
                        for pred in preds
                    ]
                else:
                    preds = preds[:total_samples + current_samples -
                                  accum_samples]
                labels = labels[:total_samples + current_samples -
                                accum_samples]
                current_samples = total_samples + current_samples - accum_samples
        else:
            labels = batch[1]
            preds = out

        # calc loss
        if engine.eval_loss_func is not None:
            if engine.amp and engine.amp_eval:
                with paddle.amp.auto_cast(
                        custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                        },
                        level=engine.amp_level):
                    loss_dict = engine.eval_loss_func(preds, labels)
            else:
                loss_dict = engine.eval_loss_func(preds, labels)

            for key in loss_dict:
                if key not in output_info:
                    output_info[key] = AverageMeter(key, '7.5f')
                output_info[key].update(float(loss_dict[key]), current_samples)

        # calc metric
        if engine.eval_metric_func is not None:
            engine.eval_metric_func(preds, labels)
        time_info["batch_cost"].update(time.time() - tic)

        if iter_id % print_batch_step == 0:
            time_msg = "s, ".join([
                "{}: {:.5f}".format(key, time_info[key].avg)
                for key in time_info
            ])
            ips_msg = "ips: {:.5f} images/sec".format(
                batch_size / time_info["batch_cost"].avg)
            if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
                metric_msg = ""
            else:
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
        tic = time.time()

    if engine.use_dali:
        engine.dataloader_dict["Eval"].reset()

    if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
        metric_msg = ", ".join([
            "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
            format(*engine.eval_metric_func.attr_res())
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
        # do not try to save best eval.model
        if engine.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return engine.eval_metric_func.attr_res()[0]
    else:
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
        # do not try to save best eval.model
        if engine.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return engine.eval_metric_func.avg
@@ -13,19 +13,16 @@
# limitations under the License.
from .train_metabin import train_epoch_metabin
-from .classification import ClassTrainer
from .regular_train_epoch import regular_train_epoch
from .train_fixmatch import train_epoch_fixmatch
from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from .train_progressive import train_epoch_progressive


-def build_train_func(config, mode, model, eval_func):
-    if mode != "train":
-        return None
-    train_mode = config["Global"].get("task", None)
-    if train_mode is None:
-        config["Global"]["task"] = "classification"
-        return ClassTrainer(config, mode, model, eval_func)
-    else:
-        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
-            config, mode, model, eval_func)
def build_train_epoch_func(config):
    train_mode = config["Global"].get("train_mode", None)
    if train_mode is None:
        config["Global"]["train_mode"] = "regular_train"
        return regular_train_epoch
    else:
        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
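build_train_epoch_func follows the same pattern: Global.train_mode selects a per-epoch function, and Engine.train() calls it once per epoch as train_epoch_func(engine, epoch_id, print_batch_step). A small illustrative sketch is shown below; the train_loop helper is hypothetical.

# Sketch only: how the selected per-epoch function is consumed.
from ppcls.engine.train import build_train_epoch_func


def train_loop(engine):
    train_epoch_func = build_train_epoch_func(engine.config)  # e.g. regular_train_epoch
    print_batch_step = engine.config["Global"]["print_batch_step"]
    for epoch_id in range(1, engine.epochs + 1):
        train_epoch_func(engine, epoch_id, print_batch_step)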
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import time
import paddle
from .utils import update_loss, update_metric, log_info
from ...utils import logger, profiler, type_name
from ...utils.misc import AverageMeter
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
from ...optimizer import build_optimizer
from ...utils.ema import ExponentialMovingAverage
from ...utils.save_load import init_model, ModelSaver
class ClassTrainer(object):
def __init__(self, config, mode, model, eval_func):
self.config = config
self.model = model
self.eval = eval_func
self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
self.epochs = self.config["Global"].get("epochs", 1)
self.print_batch_step = self.config['Global']['print_batch_step']
self.save_interval = self.config["Global"]["save_interval"]
self.output_dir = self.config['Global']['output_dir']
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
# AMP training and evaluating
# self._init_amp()
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self.config, mode)
# build loss
self.train_loss_func, self.unlabel_train_loss_func = build_loss(
self.config, mode)
# build metric
self.train_metric_func = build_metrics(config, "train")
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.dataloader_dict["Train"].max_iter,
[self.model, self.train_loss_func], self.update_freq)
# build model saver
self.model_saver = ModelSaver(
self,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema")
# build best metric
self.best_metric = {
"metric": -1.0,
"epoch": 0,
}
# key:
# val: metrics list word
self.output_info = dict()
self.time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# build EMA model
self.model_ema = self._build_ema_model()
self._init_checkpoints()
# for visualdl
self.vdl_writer = self._init_vdl()
def __call__(self):
# global iter counter
self.global_step = 0
for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train
self.train_epoch(epoch_id)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.epochs, metric_msg))
self.output_info.clear()
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc)
if acc > self.best_metric["metric"]:
self.best_metric["metric"] = acc
self.best_metric["epoch"] = epoch_id
self.model_saver.save(
self.best_metric,
prefix="best_model",
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, self.best_metric["metric"]))
logger.scaler(
name="eval_acc",
value=acc,
step=epoch_id,
writer=self.vdl_writer)
self.model.train()
if self.model_ema:
ori_model, self.model = self.model, self.model_ema.module
acc_ema = self.eval(epoch_id)
self.model = ori_model
self.model_ema.module.eval()
if acc_ema > self.best_metric["metric_ema"]:
self.best_metric["metric_ema"] = acc_ema
self.model_saver.save(
{
"metric": acc_ema,
"epoch": epoch_id
},
prefix="best_model_ema")
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, self.best_metric["metric_ema"]))
logger.scaler(
name="eval_acc_ema",
value=acc_ema,
step=epoch_id,
writer=self.vdl_writer)
# save model
if self.save_interval > 0 and epoch_id % self.save_interval == 0:
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
},
prefix=f"epoch_{epoch_id}")
# save the latest model
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
}, prefix="latest")
def train_epoch(self, epoch_id):
tic = time.time()
for iter_id in range(self.dataloader_dict["Train"].max_iter):
batch = self.dataloader_dict["Train"].get_batch()
profiler.add_profiler_step(self.config["profiler_options"])
if iter_id == 5:
for key in self.time_info:
self.time_info[key].reset()
self.time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
if not self.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])
self.global_step += 1
# forward & backward & step opt
# if engine.amp:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# scaled = engine.scaler.scale(loss)
# scaled.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.scaler.minimize(engine.optimizer[i], scaled)
# else:
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# loss.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.optimizer[i].step()
out = self.model(batch)
loss_dict = self.train_loss_func(out, batch[1])
loss = loss_dict["loss"] / self.update_freq
loss.backward()
if (iter_id + 1) % self.update_freq == 0:
for i in range(len(self.optimizer)):
self.optimizer[i].step()
if (iter_id + 1) % self.update_freq == 0:
# clear grad
for i in range(len(self.optimizer)):
self.optimizer[i].clear_grad()
# step lr(by step)
for i in range(len(self.lr_sch)):
if not getattr(self.lr_sch[i], "by_epoch", False):
self.lr_sch[i].step()
# update ema
if self.model_ema:
self.model_ema.update(self.model)
# below code just for logging
# update metric_for_logger
update_metric(self, out, batch, batch_size)
# update_loss_for_logger
update_loss(self, loss_dict, batch_size)
self.time_info["batch_cost"].update(time.time() - tic)
if iter_id % self.print_batch_step == 0:
log_info(self, batch_size, epoch_id, iter_id)
tic = time.time()
# step lr(by epoch)
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) != "ReduceOnPlateau":
self.lr_sch[i].step()
def __del__(self):
if self.vdl_writer is not None:
self.vdl_writer.close()
def _init_vdl(self):
if self.config['Global']['use_visualdl'] and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
return LogWriter(logdir=vdl_writer_path)
return None
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
self.best_metric["metric_ema"] = 0
return model_ema
else:
return None
def _init_checkpoints(self):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
self.best_metric.update(metric_info)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler
def regular_train_epoch(engine, epoch_id, print_batch_step):
tic = time.time()
for iter_id in range(engine.dataloader_dict["Train"].max_iter):
batch = engine.dataloader_dict["Train"].get_batch()
profiler.add_profiler_step(engine.config["profiler_options"])
if iter_id == 5:
for key in engine.time_info:
engine.time_info[key].reset()
engine.time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])
engine.global_step += 1
# forward & backward & step opt
if engine.amp:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch)
loss_dict = engine.train_loss_func(out, batch[1])
loss = loss_dict["loss"] / engine.update_freq
scaled = engine.scaler.scale(loss)
scaled.backward()
if (iter_id + 1) % engine.update_freq == 0:
for i in range(len(engine.optimizer)):
engine.scaler.minimize(engine.optimizer[i], scaled)
else:
out = engine.model(batch)
loss_dict = engine.train_loss_func(out, batch[1])
loss = loss_dict["loss"] / engine.update_freq
loss.backward()
if (iter_id + 1) % engine.update_freq == 0:
for i in range(len(engine.optimizer)):
engine.optimizer[i].step()
if (iter_id + 1) % engine.update_freq == 0:
# clear grad
for i in range(len(engine.optimizer)):
engine.optimizer[i].clear_grad()
# step lr(by step)
for i in range(len(engine.lr_sch)):
if not getattr(engine.lr_sch[i], "by_epoch", False):
engine.lr_sch[i].step()
# update ema
if engine.model_ema:
engine.model_ema.update(engine.model)
# below code just for logging
# update metric_for_logger
update_metric(engine, out, batch, batch_size)
# update_loss_for_logger
update_loss(engine, loss_dict, batch_size)
engine.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
log_info(engine, batch_size, epoch_id, iter_id)
tic = time.time()
# step lr(by epoch)
for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], "by_epoch", False) and \
type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
engine.lr_sch[i].step()
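regular_train_epoch scales the loss by 1/update_freq and only steps and clears the optimizers every update_freq iterations, so update_freq consecutive batches contribute to a single parameter update (an effective batch size of batch_size * update_freq). Below is a self-contained sketch of the same accumulation pattern outside the engine, using a toy model and random data with standard Paddle APIs.

# Standalone sketch of the gradient-accumulation pattern used above
# (loss scaled by 1/update_freq, optimizer stepped every update_freq iterations).
import paddle

model = paddle.nn.Linear(8, 2)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
update_freq = 4  # same role as Global.update_freq in the config

for iter_id in range(16):
    x = paddle.randn([32, 8])
    y = paddle.randn([32, 2])
    loss = paddle.nn.functional.mse_loss(model(x), y) / update_freq
    loss.backward()                      # gradients accumulate across iterations
    if (iter_id + 1) % update_freq == 0:
        opt.step()                       # one update per update_freq batches
        opt.clear_grad()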
@@ -151,20 +151,20 @@ def _extract_student_weights(all_params, student_prefix="Student."):

class ModelSaver(object):
    def __init__(self,
-                 trainer,
                 engine,
                 net_name="model",
                 loss_name="train_loss_func",
                 opt_name="optimizer",
                 model_ema_name="model_ema"):
        # net, loss, opt, model_ema, output_dir,
-        self.trainer = trainer
        self.engine = engine
        self.net_name = net_name
        self.loss_name = loss_name
        self.opt_name = opt_name
        self.model_ema_name = model_ema_name

-        arch_name = trainer.config["Arch"]["name"]
-        self.output_dir = os.path.join(trainer.output_dir, arch_name)
        arch_name = engine.config["Arch"]["name"]
        self.output_dir = os.path.join(engine.output_dir, arch_name)
        _mkdir_if_not_exist(self.output_dir)

    def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +174,8 @@ class ModelSaver(object):
        save_dir = os.path.join(self.output_dir, prefix)

-        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
-        loss = getattr(self.trainer, self.loss_name)
        params_state_dict = getattr(self.engine, self.net_name).state_dict()
        loss = getattr(self.engine, self.loss_name)
        if loss is not None:
            loss_state_dict = loss.state_dict()
            keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +190,11 @@ class ModelSaver(object):
            paddle.save(s_params, save_dir + "_student.pdparams")
        paddle.save(params_state_dict, save_dir + ".pdparams")

-        model_ema = getattr(self.trainer, self.model_ema_name)
        model_ema = getattr(self.engine, self.model_ema_name)
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        save_dir + ".ema.pdparams")
-        optimizer = getattr(self.trainer, self.opt_name)
        optimizer = getattr(self.engine, self.opt_name)
        paddle.save([opt.state_dict() for opt in optimizer],
                    save_dir + ".pdopt")
        paddle.save(metric_info, save_dir + ".pdstates")
...
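For a given prefix, ModelSaver.save writes <prefix>.pdparams (plus <prefix>_student.pdparams when student weights are extracted), <prefix>.ema.pdparams when EMA is enabled, <prefix>.pdopt and <prefix>.pdstates under output_dir/<Arch.name>/. A hedged sketch of how the restored Engine.train() loop uses it follows; the save_checkpoints helper is hypothetical.

# Sketch only: checkpoint calls mirroring Engine.train() above.
from ppcls.utils.save_load import ModelSaver


def save_checkpoints(engine, acc, epoch_id, save_interval=1):
    saver = ModelSaver(
        engine,
        net_name="model",
        loss_name="train_loss_func",
        opt_name="optimizer",
        model_ema_name="model_ema")
    if save_interval > 0 and epoch_id % save_interval == 0:
        # -> output_dir/<Arch.name>/epoch_<id>.pdparams / .pdopt / .pdstates
        saver.save({"metric": acc, "epoch": epoch_id}, prefix=f"epoch_{epoch_id}")
    # the latest checkpoint is always refreshed
    saver.save({"metric": acc, "epoch": epoch_id}, prefix="latest")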