Commit 915dde17 authored by Tingquan Gao

Revert "refactor: rm train and eval from engine"

This reverts commit 5a6fe171.
Parent aa52682c
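After this revert, Engine again exposes train()/eval() itself and dispatches to the module-level epoch functions (regular_train_epoch, classification_eval). A minimal usage sketch, assuming the standard PaddleClas entry points and an illustrative config path:

    from ppcls.utils import config as config_util
    from ppcls.engine.engine import Engine

    # illustrative config path; any PaddleClas yaml works here
    cfg = config_util.get_config(
        "ppcls/configs/quick_start/ResNet50_vd.yaml", overrides=None, show=False)
    engine = Engine(cfg, mode="train")
    engine.train()  # per epoch: train_epoch_func(...), then eval() every eval_interval epochs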
@@ -22,17 +22,25 @@ from paddle import nn
import numpy as np
import random
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model, ModelSaver
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from .train import build_train_func
from .train import build_train_epoch_func
from .evaluation import build_eval_func
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
@@ -42,35 +50,186 @@ class Engine(object):
assert mode in ["train", "eval", "infer", "export"]
self.mode = mode
self.config = config
self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
self.epochs = self.config["Global"].get("epochs", 1)
# set seed
self._init_seed()
# init logger
log_file = os.path.join(self.config['Global']['output_dir'],
self.config["Arch"]["name"], f"{mode}.log")
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
f"{mode}.log")
init_logger(log_file=log_file)
# for visualdl
self.vdl_writer = self._init_vdl()
# init train_func and eval_func
self.train_epoch_func = build_train_epoch_func(self.config)
self.eval_func = build_eval_func(self.config)
# set device
self._init_device()
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self.config, mode)
# build loss
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
self.config, self.mode)
# build metric
self.train_metric_func, self.eval_metric_func = build_metrics(self)
# build model
self.model = build_model(self.config, self.mode)
# load_pretrain
self._init_pretrained()
# init train_func and eval_func
self.eval = build_eval_func(
self.config, mode=self.mode, model=self.model)
self.train = build_train_func(
self.config, mode=self.mode, model=self.model, eval_func=self.eval)
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(self)
# AMP training and evaluating
self._init_amp()
# for distributed
self._init_dist()
# build model saver
self.model_saver = ModelSaver(
self,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema")
print_config(config)
def train(self):
assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
best_metric = {
"metric": -1.0,
"epoch": 0,
}
# key: loss/metric name
# val: AverageMeter tracking the running average
self.output_info = dict()
self.time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# build EMA model
self.model_ema = self._build_ema_model()
# TODO: mv best_metric_ema to best_metric dict
best_metric_ema = 0
self._init_checkpoints(best_metric)
# global iter counter
self.global_step = 0
for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.epochs, metric_msg))
self.output_info.clear()
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc)
if acc > best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
self.model_saver.save(
best_metric,
prefix="best_model",
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"]))
logger.scaler(
name="eval_acc",
value=acc,
step=epoch_id,
writer=self.vdl_writer)
self.model.train()
if self.model_ema:
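# swap in the EMA weights so eval() scores the averaged model, then swap back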
ori_model, self.model = self.model, self.model_ema.module
acc_ema = self.eval(epoch_id)
self.model = ori_model
self.model_ema.module.eval()
if acc_ema > best_metric_ema:
best_metric_ema = acc_ema
self.model_saver.save(
{
"metric": acc_ema,
"epoch": epoch_id
},
prefix="best_model_ema")
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, best_metric_ema))
logger.scaler(
name="eval_acc_ema",
value=acc_ema,
step=epoch_id,
writer=self.vdl_writer)
# save model
if save_interval > 0 and epoch_id % save_interval == 0:
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
},
prefix=f"epoch_{epoch_id}")
# save the latest model
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
}, prefix="latest")
if self.vdl_writer is not None:
self.vdl_writer.close()
@paddle.no_grad()
def eval(self, epoch_id=0):
assert self.mode in ["train", "eval"]
self.model.eval()
eval_result = self.eval_func(self, epoch_id)
self.model.train()
return eval_result
@paddle.no_grad()
def infer(self):
assert self.mode == "infer" and self.eval_mode == "classification"
@@ -167,6 +326,15 @@ class Engine(object):
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
)
def _init_vdl(self):
if self.config['Global'][
'use_visualdl'] and self.mode == "train" and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
return LogWriter(logdir=vdl_writer_path)
return None
def _init_seed(self):
seed = self.config["Global"].get("seed", False)
if dist.get_world_size() != 1:
@@ -287,6 +455,22 @@ class Engine(object):
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
return model_ema
else:
return None
def _init_checkpoints(self, best_metric):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
best_metric.update(metric_info)
class ExportModel(TheseusLayer):
"""
......
@@ -12,18 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from .classification import ClassEval
from .classification import classification_eval
from .retrieval import retrieval_eval
from .adaface import adaface_eval
def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]:
return None
def build_eval_func(config):
eval_mode = config["Global"].get("eval_mode", None)
if eval_mode is None:
config["Global"]["eval_mode"] = "classification"
return ClassEval(config, mode, model)
return classification_eval
else:
return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
mode, model)
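# dispatch by name, e.g. eval_mode "retrieval" -> retrieval_eval, "adaface" -> adaface_eval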
return getattr(sys.modules[__name__], eval_mode + "_eval")
@@ -18,185 +18,164 @@ import time
import platform
import paddle
from ...utils.misc import AverageMeter
from ...utils import logger
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
class ClassEval(object):
def __init__(self, config, mode, model):
self.config = config
self.model = model
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(config, "eval")
self.eval_dataloader = build_dataloader(config, "eval")
self.eval_loss_func = build_loss(config, "eval")
self.output_info = dict()
@paddle.no_grad()
def __call__(self, epoch_id=0):
self.model.eval()
if hasattr(self.eval_metric_func, "reset"):
self.eval_metric_func.reset()
time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
print_batch_step = self.config["Global"]["print_batch_step"]
tic = time.time()
total_samples = self.eval_dataloader["Eval"].total_samples
accum_samples = 0
max_iter = self.eval_dataloader["Eval"].max_iter
for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
if iter_id >= max_iter:
break
if iter_id == 5:
for key in time_info:
time_info[key].reset()
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0])
if not self.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input
# if engine.amp and engine.amp_eval:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# out = engine.model(batch)
# else:
# out = self.model(batch)
out = self.model(batch)
# account for DistributedBatchSampler repeating samples to fill the last batch
current_samples = batch_size * paddle.distributed.get_world_size()
accum_samples += current_samples
if isinstance(out, dict) and "Student" in out:
out = out["Student"]
if isinstance(out, dict) and "logits" in out:
out = out["logits"]
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
device_id = paddle.distributed.ParallelEnv().device_id
label = batch[1].cuda(device_id) if self.config["Global"][
"device"] == "gpu" else batch[1]
paddle.distributed.all_gather(label_list, label)
labels = paddle.concat(label_list, 0)
if isinstance(out, list):
preds = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
preds.append(pred_x)
else:
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
def classification_eval(engine, epoch_id=0):
if hasattr(engine.eval_metric_func, "reset"):
engine.eval_metric_func.reset()
output_info = dict()
time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
print_batch_step = engine.config["Global"]["print_batch_step"]
tic = time.time()
total_samples = engine.dataloader_dict["Eval"].total_samples
accum_samples = 0
max_iter = engine.dataloader_dict["Eval"].max_iter
for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
if iter_id >= max_iter:
break
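# reset timers after a few warmup iterations so startup cost does not skew the averages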
if iter_id == 5:
for key in time_info:
time_info[key].reset()
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0])
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input
if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch)
else:
out = engine.model(batch)
# account for DistributedBatchSampler repeating samples to fill the last batch
current_samples = batch_size * paddle.distributed.get_world_size()
accum_samples += current_samples
if isinstance(out, dict) and "Student" in out:
out = out["Student"]
if isinstance(out, dict) and "logits" in out:
out = out["logits"]
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
device_id = paddle.distributed.ParallelEnv().device_id
label = batch[1].cuda(device_id) if engine.config["Global"][
"device"] == "gpu" else batch[1]
paddle.distributed.all_gather(label_list, label)
labels = paddle.concat(label_list, 0)
if isinstance(out, list):
preds = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
preds = paddle.concat(pred_list, 0)
if accum_samples > total_samples and not self.use_dali:
if isinstance(preds, list):
preds = [
pred[:total_samples + current_samples -
accum_samples] for pred in preds
]
else:
preds = preds[:total_samples + current_samples -
accum_samples]
labels = labels[:total_samples + current_samples -
accum_samples]
current_samples = total_samples + current_samples - accum_samples
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
preds.append(pred_x)
else:
labels = batch[1]
preds = out
# calc loss
if self.eval_loss_func is not None:
# if self.amp and self.amp_eval:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# loss_dict = engine.eval_loss_func(preds, labels)
# else:
loss_dict = self.eval_loss_func(preds, labels)
for key in loss_dict:
if key not in self.output_info:
self.output_info[key] = AverageMeter(key, '7.5f')
self.output_info[key].update(
float(loss_dict[key]), current_samples)
# calc metric
if self.eval_metric_func is not None:
self.eval_metric_func(preds, labels)
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
time_msg = "s, ".join([
"{}: {:.5f}".format(key, time_info[key].avg)
for key in time_info
])
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
metric_msg = ""
pred_list = []
paddle.distributed.all_gather(pred_list, out)
preds = paddle.concat(pred_list, 0)
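# when the gathered count exceeds total_samples, the tail is sampler padding;
# trim preds/labels so metrics cover each sample exactly once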
if accum_samples > total_samples and not engine.use_dali:
if isinstance(preds, list):
preds = [
pred[:total_samples + current_samples - accum_samples]
for pred in preds
]
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, self.output_info[key].val)
for key in self.output_info
])
metric_msg += ", {}".format(self.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id, max_iter, metric_msg, time_msg,
ips_msg))
tic = time.time()
if self.use_dali:
self.eval_dataloader["Eval"].reset()
if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*self.eval_metric_func.attr_res())
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save the best eval model
if self.eval_metric_func is None:
return -1
# return 1st metric in the dict
return self.eval_metric_func.attr_res()[0]
preds = preds[:total_samples + current_samples -
accum_samples]
labels = labels[:total_samples + current_samples -
accum_samples]
current_samples = total_samples + current_samples - accum_samples
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, self.output_info[key].avg)
for key in self.output_info
labels = batch[1]
preds = out
# calc loss
if engine.eval_loss_func is not None:
if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
loss_dict = engine.eval_loss_func(preds, labels)
else:
loss_dict = engine.eval_loss_func(preds, labels)
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(float(loss_dict[key]), current_samples)
# calc metric
if engine.eval_metric_func is not None:
engine.eval_metric_func(preds, labels)
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
time_msg = "s, ".join([
"{}: {:.5f}".format(key, time_info[key].avg)
for key in time_info
])
metric_msg += ", {}".format(self.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save the best eval model
if self.eval_metric_func is None:
return -1
# return 1st metric in the dict
return self.eval_metric_func.avg
self.model.train()
return eval_result
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ""
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
tic = time.time()
if engine.use_dali:
engine.dataloader_dict["Eval"].reset()
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*engine.eval_metric_func.attr_res())
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save the best eval model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return engine.eval_metric_func.attr_res()[0]
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save the best eval model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return engine.eval_metric_func.avg
@@ -13,19 +13,16 @@
# limitations under the License.
import sys
from .train_metabin import train_epoch_metabin
from .classification import ClassTrainer
from .regular_train_epoch import regular_train_epoch
from .train_fixmatch import train_epoch_fixmatch
from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from .train_progressive import train_epoch_progressive
def build_train_func(config, mode, model, eval_func):
if mode != "train":
return None
train_mode = config["Global"].get("task", None)
def build_train_epoch_func(config):
train_mode = config["Global"].get("train_mode", None)
if train_mode is None:
config["Global"]["task"] = "classification"
return ClassTrainer(config, mode, model, eval_func)
config["Global"]["train_mode"] = "regular_train"
return regular_train_epoch
else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
config, mode, model, eval_func)
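# dispatch by name, e.g. train_mode "fixmatch" -> train_epoch_fixmatch, "progressive" -> train_epoch_progressive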
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import time
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
from .utils import update_loss, update_metric, log_info
from ...utils import logger, profiler, type_name
from ...utils.misc import AverageMeter
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
from ...optimizer import build_optimizer
from ...utils.ema import ExponentialMovingAverage
from ...utils.save_load import init_model, ModelSaver
class ClassTrainer(object):
def __init__(self, config, mode, model, eval_func):
self.config = config
self.mode = mode
self.model = model
self.eval = eval_func
self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
self.epochs = self.config["Global"].get("epochs", 1)
self.print_batch_step = self.config['Global']['print_batch_step']
self.save_interval = self.config["Global"]["save_interval"]
self.output_dir = self.config['Global']['output_dir']
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
# AMP training and evaluating
# self._init_amp()
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader_dict = build_dataloader(self.config, mode)
# build loss
self.train_loss_func, self.unlabel_train_loss_func = build_loss(
self.config, mode)
# build metric
self.train_metric_func = build_metrics(config, "train")
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.dataloader_dict["Train"].max_iter,
[self.model, self.train_loss_func], self.update_freq)
# build model saver
self.model_saver = ModelSaver(
self,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema")
# build best metric
self.best_metric = {
"metric": -1.0,
"epoch": 0,
}
# key: loss/metric name
# val: AverageMeter tracking the running average
self.output_info = dict()
self.time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# build EMA model
self.model_ema = self._build_ema_model()
self._init_checkpoints()
# for visualdl
self.vdl_writer = self._init_vdl()
def __call__(self):
# global iter counter
self.global_step = 0
for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train
self.train_epoch(epoch_id)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.epochs, metric_msg))
self.output_info.clear()
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc)
if acc > self.best_metric["metric"]:
self.best_metric["metric"] = acc
self.best_metric["epoch"] = epoch_id
self.model_saver.save(
self.best_metric,
prefix="best_model",
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, self.best_metric["metric"]))
logger.scaler(
name="eval_acc",
value=acc,
step=epoch_id,
writer=self.vdl_writer)
self.model.train()
if self.model_ema:
ori_model, self.model = self.model, self.model_ema.module
acc_ema = self.eval(epoch_id)
self.model = ori_model
self.model_ema.module.eval()
if acc_ema > self.best_metric["metric_ema"]:
self.best_metric["metric_ema"] = acc_ema
self.model_saver.save(
{
"metric": acc_ema,
"epoch": epoch_id
},
prefix="best_model_ema")
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, self.best_metric["metric_ema"]))
logger.scaler(
name="eval_acc_ema",
value=acc_ema,
step=epoch_id,
writer=self.vdl_writer)
# save model
if self.save_interval > 0 and epoch_id % self.save_interval == 0:
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
},
prefix=f"epoch_{epoch_id}")
# save the latest model
self.model_saver.save(
{
"metric": acc,
"epoch": epoch_id
}, prefix="latest")
def train_epoch(self, epoch_id):
tic = time.time()
for iter_id in range(self.dataloader_dict["Train"].max_iter):
batch = self.dataloader_dict["Train"].get_batch()
profiler.add_profiler_step(self.config["profiler_options"])
if iter_id == 5:
for key in self.time_info:
self.time_info[key].reset()
self.time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
if not self.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])
self.global_step += 1
# forward & backward & step opt
# if engine.amp:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# scaled = engine.scaler.scale(loss)
# scaled.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.scaler.minimize(engine.optimizer[i], scaled)
# else:
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# loss.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.optimizer[i].step()
out = self.model(batch)
loss_dict = self.train_loss_func(out, batch[1])
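# gradient accumulation: scale the loss by 1/update_freq and step the optimizer every update_freq iterations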
loss = loss_dict["loss"] / self.update_freq
loss.backward()
if (iter_id + 1) % self.update_freq == 0:
for i in range(len(self.optimizer)):
self.optimizer[i].step()
if (iter_id + 1) % self.update_freq == 0:
# clear grad
for i in range(len(self.optimizer)):
self.optimizer[i].clear_grad()
# step lr(by step)
for i in range(len(self.lr_sch)):
if not getattr(self.lr_sch[i], "by_epoch", False):
self.lr_sch[i].step()
# update ema
if self.model_ema:
self.model_ema.update(self.model)
# the code below is only for logging
# update metric_for_logger
update_metric(self, out, batch, batch_size)
# update_loss_for_logger
update_loss(self, loss_dict, batch_size)
self.time_info["batch_cost"].update(time.time() - tic)
if iter_id % self.print_batch_step == 0:
log_info(self, batch_size, epoch_id, iter_id)
tic = time.time()
# step lr(by epoch)
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) != "ReduceOnPlateau":
self.lr_sch[i].step()
def __del__(self):
if self.vdl_writer is not None:
self.vdl_writer.close()
def _init_vdl(self):
if self.config['Global']['use_visualdl'] and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
return LogWriter(logdir=vdl_writer_path)
return None
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
self.best_metric["metric_ema"] = 0
return model_ema
else:
return None
def _init_checkpoints(self):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
self.best_metric.update(metric_info)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler
def regular_train_epoch(engine, epoch_id, print_batch_step):
tic = time.time()
for iter_id in range(engine.dataloader_dict["Train"].max_iter):
batch = engine.dataloader_dict["Train"].get_batch()
profiler.add_profiler_step(engine.config["profiler_options"])
if iter_id == 5:
for key in engine.time_info:
engine.time_info[key].reset()
engine.time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])
engine.global_step += 1
# forward & backward & step opt
if engine.amp:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch)
loss_dict = engine.train_loss_func(out, batch[1])
loss = loss_dict["loss"] / engine.update_freq
scaled = engine.scaler.scale(loss)
scaled.backward()
if (iter_id + 1) % engine.update_freq == 0:
for i in range(len(engine.optimizer)):
engine.scaler.minimize(engine.optimizer[i], scaled)
else:
out = engine.model(batch)
loss_dict = engine.train_loss_func(out, batch[1])
loss = loss_dict["loss"] / engine.update_freq
loss.backward()
if (iter_id + 1) % engine.update_freq == 0:
for i in range(len(engine.optimizer)):
engine.optimizer[i].step()
if (iter_id + 1) % engine.update_freq == 0:
# clear grad
for i in range(len(engine.optimizer)):
engine.optimizer[i].clear_grad()
# step lr(by step)
for i in range(len(engine.lr_sch)):
if not getattr(engine.lr_sch[i], "by_epoch", False):
engine.lr_sch[i].step()
# update ema
if engine.model_ema:
engine.model_ema.update(engine.model)
# the code below is only for logging
# update metric_for_logger
update_metric(engine, out, batch, batch_size)
# update_loss_for_logger
update_loss(engine, loss_dict, batch_size)
engine.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
log_info(engine, batch_size, epoch_id, iter_id)
tic = time.time()
# step lr(by epoch)
for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], "by_epoch", False) and \
type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
engine.lr_sch[i].step()
@@ -151,20 +151,20 @@ def _extract_student_weights(all_params, student_prefix="Student."):
class ModelSaver(object):
def __init__(self,
trainer,
engine,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema"):
# net, loss, opt, model_ema, output_dir,
self.trainer = trainer
self.engine = engine
self.net_name = net_name
self.loss_name = loss_name
self.opt_name = opt_name
self.model_ema_name = model_ema_name
arch_name = trainer.config["Arch"]["name"]
self.output_dir = os.path.join(trainer.output_dir, arch_name)
arch_name = engine.config["Arch"]["name"]
self.output_dir = os.path.join(engine.output_dir, arch_name)
_mkdir_if_not_exist(self.output_dir)
def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +174,8 @@ class ModelSaver(object):
save_dir = os.path.join(self.output_dir, prefix)
params_state_dict = getattr(self.trainer, self.net_name).state_dict()
loss = getattr(self.trainer, self.loss_name)
params_state_dict = getattr(self.engine, self.net_name).state_dict()
loss = getattr(self.engine, self.loss_name)
if loss is not None:
loss_state_dict = loss.state_dict()
keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +190,11 @@ class ModelSaver(object):
paddle.save(s_params, save_dir + "_student.pdparams")
paddle.save(params_state_dict, save_dir + ".pdparams")
model_ema = getattr(self.trainer, self.model_ema_name)
model_ema = getattr(self.engine, self.model_ema_name)
if model_ema is not None:
paddle.save(model_ema.module.state_dict(),
save_dir + ".ema.pdparams")
optimizer = getattr(self.trainer, self.opt_name)
optimizer = getattr(self.engine, self.opt_name)
paddle.save([opt.state_dict() for opt in optimizer],
save_dir + ".pdopt")
paddle.save(metric_info, save_dir + ".pdstates")
......