diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index cb43b6a9e406e0f8d243e34d8bb14a8b05eb6aa2..c9cb3cf3d1f576829b710b5951e4180cb5b888ec 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -22,25 +22,17 @@ from paddle import nn
 import numpy as np
 import random
 
-from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
-from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
-from ppcls.loss import build_loss
-from ppcls.metric import build_metrics
-from ppcls.optimizer import build_optimizer
-from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
-from ppcls.utils.save_load import init_model, ModelSaver
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
-from .train import build_train_epoch_func
+from .train import build_train_func
 from .evaluation import build_eval_func
-from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead
 
@@ -50,186 +42,35 @@ class Engine(object):
         assert mode in ["train", "eval", "infer", "export"]
         self.mode = mode
         self.config = config
-        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
-                                                          0) - 1
-        self.epochs = self.config["Global"].get("epochs", 1)
 
         # set seed
         self._init_seed()
 
         # init logger
-        self.output_dir = self.config['Global']['output_dir']
-        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
-                                f"{mode}.log")
+        log_file = os.path.join(self.config['Global']['output_dir'],
+                                self.config["Arch"]["name"], f"{mode}.log")
         init_logger(log_file=log_file)
 
-        # for visualdl
-        self.vdl_writer = self._init_vdl()
-
-        # init train_func and eval_func
-        self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_func = build_eval_func(self.config)
-
         # set device
         self._init_device()
 
-        # gradient accumulation
-        self.update_freq = self.config["Global"].get("update_freq", 1)
-
-        # build dataloader
-        self.use_dali = self.config["Global"].get("use_dali", False)
-        self.dataloader_dict = build_dataloader(self.config, mode)
-
-        # build loss
-        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
-            self.config, self.mode)
-
-        # build metric
-        self.train_metric_func, self.eval_metric_func = build_metrics(self)
-
         # build model
         self.model = build_model(self.config, self.mode)
 
         # load_pretrain
         self._init_pretrained()
 
-        # build optimizer
-        self.optimizer, self.lr_sch = build_optimizer(self)
-
-        # AMP training and evaluating
-        self._init_amp()
+        # init train_func and eval_func
+        self.eval = build_eval_func(
+            self.config, mode=self.mode, model=self.model)
+        self.train = build_train_func(
+            self.config, mode=self.mode, model=self.model, eval_func=self.eval)
 
         # for distributed
         self._init_dist()
 
-        # build model saver
-        self.model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
-
         print_config(config)
 
-    def train(self):
-        assert self.mode == "train"
-        print_batch_step = self.config['Global']['print_batch_step']
-        save_interval = self.config["Global"]["save_interval"]
-
-        best_metric = {
-            "metric": -1.0,
-            "epoch": 0,
-        }
-
-        # key:
-        # val: metrics list word
-        self.output_info = dict()
-        self.time_info = {
-            "batch_cost": AverageMeter(
-                "batch_cost", '.5f', postfix=" s,"),
-            "reader_cost": AverageMeter(
-                "reader_cost", ".5f", postfix=" s,"),
-        }
-
-        # build EMA model
-        self.model_ema = self._build_ema_model()
-        # TODO: mv best_metric_ema to best_metric dict
-        best_metric_ema = 0
-
-        self._init_checkpoints(best_metric)
-
-        # global iter counter
-        self.global_step = 0
-        for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
-            # for one epoch train
-            self.train_epoch_func(self, epoch_id, print_batch_step)
-
-            metric_msg = ", ".join(
-                [self.output_info[key].avg_info for key in self.output_info])
-            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
-                epoch_id, self.epochs, metric_msg))
-            self.output_info.clear()
-
-            acc = 0.0
-            if self.config["Global"][
-                    "eval_during_train"] and epoch_id % self.config["Global"][
-                        "eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
-                acc = self.eval(epoch_id)
-
-                # step lr (by epoch) according to given metric, such as acc
-                for i in range(len(self.lr_sch)):
-                    if getattr(self.lr_sch[i], "by_epoch", False) and \
-                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
-                        self.lr_sch[i].step(acc)
-
-                if acc > best_metric["metric"]:
-                    best_metric["metric"] = acc
-                    best_metric["epoch"] = epoch_id
-                    self.model_saver.save(
-                        best_metric,
-                        prefix="best_model",
-                        save_student_model=True)
-
-                logger.info("[Eval][Epoch {}][best metric: {}]".format(
-                    epoch_id, best_metric["metric"]))
-                logger.scaler(
-                    name="eval_acc",
-                    value=acc,
-                    step=epoch_id,
-                    writer=self.vdl_writer)
-
-                self.model.train()
-
-                if self.model_ema:
-                    ori_model, self.model = self.model, self.model_ema.module
-                    acc_ema = self.eval(epoch_id)
-                    self.model = ori_model
-                    self.model_ema.module.eval()
-
-                    if acc_ema > best_metric_ema:
-                        best_metric_ema = acc_ema
-                        self.model_saver.save(
-                            {
-                                "metric": acc_ema,
-                                "epoch": epoch_id
-                            },
-                            prefix="best_model_ema")
-                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
-                        epoch_id, best_metric_ema))
-                    logger.scaler(
-                        name="eval_acc_ema",
-                        value=acc_ema,
-                        step=epoch_id,
-                        writer=self.vdl_writer)
-
-            # save model
-            if save_interval > 0 and epoch_id % save_interval == 0:
-                self.model_saver.save(
-                    {
-                        "metric": acc,
-                        "epoch": epoch_id
-                    },
-                    prefix=f"epoch_{epoch_id}")
-
-            # save the latest model
-            self.model_saver.save(
-                {
-                    "metric": acc,
-                    "epoch": epoch_id
-                }, prefix="latest")
-
-        if self.vdl_writer is not None:
-            self.vdl_writer.close()
-
-    @paddle.no_grad()
-    def eval(self, epoch_id=0):
-        assert self.mode in ["train", "eval"]
-        self.model.eval()
-        eval_result = self.eval_func(self, epoch_id)
-        self.model.train()
-        return eval_result
-
     @paddle.no_grad()
     def infer(self):
         assert self.mode == "infer" and self.eval_mode == "classification"
@@ -326,15 +167,6 @@ class Engine(object):
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
         )
 
-    def _init_vdl(self):
-        if self.config['Global'][
-                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
-            vdl_writer_path = os.path.join(self.output_dir, "vdl")
-            if not os.path.exists(vdl_writer_path):
-                os.makedirs(vdl_writer_path)
-            return LogWriter(logdir=vdl_writer_path)
-        return None
-
     def _init_seed(self):
         seed = self.config["Global"].get("seed", False)
         if dist.get_world_size() != 1:
@@ -455,22 +287,6 @@ class Engine(object):
                 self.train_loss_func = paddle.DataParallel(
                     self.train_loss_func)
 
-    def _build_ema_model(self):
-        if "EMA" in self.config and self.mode == "train":
-            model_ema = ExponentialMovingAverage(
-                self.model, self.config['EMA'].get("decay", 0.9999))
-            return model_ema
-        else:
-            return None
-
-    def _init_checkpoints(self, best_metric):
-        if self.config["Global"].get("checkpoints", None) is not None:
-            metric_info = init_model(self.config.Global, self.model,
-                                     self.optimizer, self.train_loss_func,
-                                     self.model_ema)
-            if metric_info is not None:
-                best_metric.update(metric_info)
-
 
 class ExportModel(TheseusLayer):
     """
diff --git a/ppcls/engine/evaluation/__init__.py b/ppcls/engine/evaluation/__init__.py
index 82bffcfebac8a57a84e3f785ab9c604842a07be2..c9499db0d98b1532ead0c61210933cbfc3e04c94 100644
--- a/ppcls/engine/evaluation/__init__.py
+++ b/ppcls/engine/evaluation/__init__.py
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .classification import classification_eval
+from .classification import ClassEval
 from .retrieval import retrieval_eval
 from .adaface import adaface_eval
 
 
-def build_eval_func(config):
+def build_eval_func(config, mode, model):
+    if mode not in ["eval", "train"]:
+        return None
     eval_mode = config["Global"].get("eval_mode", None)
     if eval_mode is None:
         config["Global"]["eval_mode"] = "classification"
-        return classification_eval
+        return ClassEval(config, mode, model)
     else:
-        return getattr(sys.modules[__name__], eval_mode + "_eval")
+        return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
+                                                                   mode, model)
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index e37f0e2af84c4eb14aeb0e88ee1d6837ec8e7f06..1c0d2d747f15443598e30758aa62b91e0741920f 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -18,164 +18,185 @@ import time
 import platform
 import paddle
 
-from ppcls.utils.misc import AverageMeter
-from ppcls.utils import logger
-
-
-def classification_eval(engine, epoch_id=0):
-    if hasattr(engine.eval_metric_func, "reset"):
-        engine.eval_metric_func.reset()
-    output_info = dict()
-    time_info = {
-        "batch_cost": AverageMeter(
-            "batch_cost", '.5f', postfix=" s,"),
-        "reader_cost": AverageMeter(
-            "reader_cost", ".5f", postfix=" s,"),
-    }
-    print_batch_step = engine.config["Global"]["print_batch_step"]
-
-    tic = time.time()
-    total_samples = engine.dataloader_dict["Eval"].total_samples
-    accum_samples = 0
-    max_iter = engine.dataloader_dict["Eval"].max_iter
-    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
-        if iter_id >= max_iter:
-            break
-        if iter_id == 5:
-            for key in time_info:
-                time_info[key].reset()
-
-        time_info["reader_cost"].update(time.time() - tic)
-        batch_size = batch[0].shape[0]
-        batch[0] = paddle.to_tensor(batch[0])
-        if not engine.config["Global"].get("use_multilabel", False):
-            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
-
-        # image input
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(
-                    custom_black_list={
-                        "flatten_contiguous_range", "greater_than"
-                    },
-                    level=engine.amp_level):
-                out = engine.model(batch)
-        else:
-            out = engine.model(batch)
-
-        # just for DistributedBatchSampler issue: repeat sampling
-        current_samples = batch_size * paddle.distributed.get_world_size()
-        accum_samples += current_samples
-
-        if isinstance(out, dict) and "Student" in out:
-            out = out["Student"]
-        if isinstance(out, dict) and "logits" in out:
-            out = out["logits"]
-
-        # gather Tensor when distributed
-        if paddle.distributed.get_world_size() > 1:
-            label_list = []
-            device_id = paddle.distributed.ParallelEnv().device_id
-            label = batch[1].cuda(device_id) if engine.config["Global"][
-                "device"] == "gpu" else batch[1]
-            paddle.distributed.all_gather(label_list, label)
-            labels = paddle.concat(label_list, 0)
-
-            if isinstance(out, list):
-                preds = []
-                for x in out:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, x)
-                    pred_x = paddle.concat(pred_list, 0)
-                    preds.append(pred_x)
-            else:
-                pred_list = []
-                paddle.distributed.all_gather(pred_list, out)
-                preds = paddle.concat(pred_list, 0)
-
-        if accum_samples > total_samples and not engine.use_dali:
-            if isinstance(preds, list):
-                preds = [
-                    pred[:total_samples + current_samples - accum_samples]
-                    for pred in preds
-                ]
+from ...utils.misc import AverageMeter
+from ...utils import logger
+from ...data import build_dataloader
+from ...loss import build_loss
+from ...metric import build_metrics
+
+
+class ClassEval(object):
+    def __init__(self, config, mode, model):
+        self.config = config
+        self.model = model
+        self.use_dali = self.config["Global"].get("use_dali", False)
+        self.eval_metric_func = build_metrics(config, "eval")
+        self.eval_dataloader = build_dataloader(config, "eval")
+        self.eval_loss_func = build_loss(config, "eval")
+        self.output_info = dict()
+
+    @paddle.no_grad()
+    def __call__(self, epoch_id=0):
+        self.model.eval()
+
+        # re-create the loss meters on every call so repeated evaluations
+        # (e.g. eval_during_train) do not accumulate across epochs
+        self.output_info = dict()
+
+        if hasattr(self.eval_metric_func, "reset"):
+            self.eval_metric_func.reset()
+
+        time_info = {
+            "batch_cost": AverageMeter(
+                "batch_cost", '.5f', postfix=" s,"),
+            "reader_cost": AverageMeter(
+                "reader_cost", ".5f", postfix=" s,"),
+        }
+        print_batch_step = self.config["Global"]["print_batch_step"]
+
+        tic = time.time()
+        total_samples = self.eval_dataloader["Eval"].total_samples
+        accum_samples = 0
+        max_iter = self.eval_dataloader["Eval"].max_iter
+        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
+            if iter_id >= max_iter:
+                break
+            if iter_id == 5:
+                for key in time_info:
+                    time_info[key].reset()
+
+            time_info["reader_cost"].update(time.time() - tic)
+            batch_size = batch[0].shape[0]
+            batch[0] = paddle.to_tensor(batch[0])
+            if not self.config["Global"].get("use_multilabel", False):
+                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+
+            # image input
+            # if engine.amp and engine.amp_eval:
+            #     with paddle.amp.auto_cast(
+            #             custom_black_list={
+            #                 "flatten_contiguous_range", "greater_than"
+            #             },
+            #             level=engine.amp_level):
+            #         out = engine.model(batch)
+            # else:
+            #     out = self.model(batch)
+            out = self.model(batch)
+
+            # just for DistributedBatchSampler issue: repeat sampling
+            current_samples = batch_size * paddle.distributed.get_world_size()
+            accum_samples += current_samples
+
+            if isinstance(out, dict) and "Student" in out:
+                out = out["Student"]
+            if isinstance(out, dict) and "logits" in out:
+                out = out["logits"]
+
+            # gather Tensor when distributed
+            if paddle.distributed.get_world_size() > 1:
+                label_list = []
+                device_id = paddle.distributed.ParallelEnv().device_id
+                label = batch[1].cuda(device_id) if self.config["Global"][
+                    "device"] == "gpu" else batch[1]
+                paddle.distributed.all_gather(label_list, label)
+                labels = paddle.concat(label_list, 0)
+
+                if isinstance(out, list):
+                    preds = []
+                    for x in out:
+                        pred_list = []
+                        paddle.distributed.all_gather(pred_list, x)
+                        pred_x = paddle.concat(pred_list, 0)
+                        preds.append(pred_x)
                 else:
-            preds = preds[:total_samples + current_samples -
-                          accum_samples]
-            labels = labels[:total_samples + current_samples -
-                            accum_samples]
-            current_samples = total_samples + current_samples - accum_samples
-        else:
-            labels = batch[1]
-            preds = out
-
-        # calc loss
-        if engine.eval_loss_func is not None:
-            if engine.amp and engine.amp_eval:
-                with paddle.amp.auto_cast(
-                        custom_black_list={
-                            "flatten_contiguous_range", "greater_than"
-                        },
-                        level=engine.amp_level):
-                    loss_dict = engine.eval_loss_func(preds, labels)
+                    pred_list = []
+                    paddle.distributed.all_gather(pred_list, out)
+                    preds = paddle.concat(pred_list, 0)
+
+                if accum_samples > total_samples and not self.use_dali:
+                    if isinstance(preds, list):
+                        preds = [
+                            pred[:total_samples + current_samples -
+                                 accum_samples] for pred in preds
+                        ]
+                    else:
+                        preds = preds[:total_samples + current_samples -
+                                      accum_samples]
+                    labels = labels[:total_samples + current_samples -
+                                    accum_samples]
+                    current_samples = total_samples + current_samples - accum_samples
             else:
-                    loss_dict = engine.eval_loss_func(preds, labels)
-
-        for key in loss_dict:
-            if key not in output_info:
-                output_info[key] = AverageMeter(key, '7.5f')
-            output_info[key].update(float(loss_dict[key]), current_samples)
-
-        # calc metric
-        if engine.eval_metric_func is not None:
-            engine.eval_metric_func(preds, labels)
-        time_info["batch_cost"].update(time.time() - tic)
-
-        if iter_id % print_batch_step == 0:
-            time_msg = "s, ".join([
-                "{}: {:.5f}".format(key, time_info[key].avg)
-                for key in time_info
-            ])
+                labels = batch[1]
+                preds = out
+
+            # calc loss
+            if self.eval_loss_func is not None:
+                # if self.amp and self.amp_eval:
+                #     with paddle.amp.auto_cast(
+                #             custom_black_list={
+                #                 "flatten_contiguous_range", "greater_than"
+                #             },
+                #             level=engine.amp_level):
+                #         loss_dict = engine.eval_loss_func(preds, labels)
+                # else:
+                loss_dict = self.eval_loss_func(preds, labels)
+
+                for key in loss_dict:
+                    if key not in self.output_info:
+                        self.output_info[key] = AverageMeter(key, '7.5f')
+                    self.output_info[key].update(
+                        float(loss_dict[key]), current_samples)
+
+            # calc metric
+            if self.eval_metric_func is not None:
+                self.eval_metric_func(preds, labels)
+            time_info["batch_cost"].update(time.time() - tic)
+
+            if iter_id % print_batch_step == 0:
+                time_msg = "s, ".join([
+                    "{}: {:.5f}".format(key, time_info[key].avg)
+                    for key in time_info
+                ])
 
-            ips_msg = "ips: {:.5f} images/sec".format(
-                batch_size / time_info["batch_cost"].avg)
+                ips_msg = "ips: {:.5f} images/sec".format(
+                    batch_size / time_info["batch_cost"].avg)
 
-            if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
-                metric_msg = ""
-            else:
-                metric_msg = ", ".join([
-                    "{}: {:.5f}".format(key, output_info[key].val)
-                    for key in output_info
-                ])
-                metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
-            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
+                if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
+                    metric_msg = ""
+                else:
+                    metric_msg = ", ".join([
+                        "{}: {:.5f}".format(key, self.output_info[key].val)
+                        for key in self.output_info
+                    ])
+                    metric_msg += ", {}".format(self.eval_metric_func.avg_info)
{}".format(self.eval_metric_func.avg_info) + logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format( + epoch_id, iter_id, max_iter, metric_msg, time_msg, + ips_msg)) + + tic = time.time() + if self.use_dali: + self.eval_dataloader["Eval"].reset() + + if "ATTRMetric" in self.config["Metric"]["Eval"][0]: + metric_msg = ", ".join([ + "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}". + format(*self.eval_metric_func.attr_res()) + ]) + logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) - tic = time.time() - if engine.use_dali: - engine.dataloader_dict["Eval"].reset() - - if "ATTRMetric" in engine.config["Metric"]["Eval"][0]: - metric_msg = ", ".join([ - "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}". - format(*engine.eval_metric_func.attr_res()) - ]) - logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) - - # do not try to save best eval.model - if engine.eval_metric_func is None: - return -1 - # return 1st metric in the dict - return engine.eval_metric_func.attr_res()[0] - else: - metric_msg = ", ".join([ - "{}: {:.5f}".format(key, output_info[key].avg) - for key in output_info - ]) - metric_msg += ", {}".format(engine.eval_metric_func.avg_info) - logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) - - # do not try to save best eval.model - if engine.eval_metric_func is None: - return -1 - # return 1st metric in the dict - return engine.eval_metric_func.avg + # do not try to save best eval.model + if self.eval_metric_func is None: + return -1 + # return 1st metric in the dict + return self.eval_metric_func.attr_res()[0] + else: + metric_msg = ", ".join([ + "{}: {:.5f}".format(key, self.output_info[key].avg) + for key in self.output_info + ]) + metric_msg += ", {}".format(self.eval_metric_func.avg_info) + logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) + + # do not try to save best eval.model + if self.eval_metric_func is None: + return -1 + # return 1st metric in the dict + return self.eval_metric_func.avg + self.model.train() + return eval_result diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index ec3c40032bc1f0a407a7eb886a72f30457bf23c3..eedba871ef2a86b8c0510b33e42d63fc72d8f04b 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -13,16 +13,19 @@ # limitations under the License. 
 from .train_metabin import train_epoch_metabin
-from .regular_train_epoch import regular_train_epoch
+from .classification import ClassTrainer
 from .train_fixmatch import train_epoch_fixmatch
 from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
 from .train_progressive import train_epoch_progressive
 
 
-def build_train_epoch_func(config):
-    train_mode = config["Global"].get("train_mode", None)
+def build_train_func(config, mode, model, eval_func):
+    if mode != "train":
+        return None
+    train_mode = config["Global"].get("task", None)
     if train_mode is None:
-        config["Global"]["train_mode"] = "regular_train"
-        return regular_train_epoch
+        config["Global"]["task"] = "classification"
+        return ClassTrainer(config, mode, model, eval_func)
     else:
-        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
+        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
+            config, mode, model, eval_func)
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fad5961cdad934a03a4051437b5b4e2f4078dcb
--- /dev/null
+++ b/ppcls/engine/train/classification.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function
+
+import os
+import time
+
+import paddle
+# needed by _init_vdl below: rank check and VisualDL writer
+import paddle.distributed as dist
+from visualdl import LogWriter
+
+from .utils import update_loss, update_metric, log_info
+from ...utils import logger, profiler, type_name
+from ...utils.misc import AverageMeter
+from ...data import build_dataloader
+from ...loss import build_loss
+from ...metric import build_metrics
+from ...optimizer import build_optimizer
+from ...utils.ema import ExponentialMovingAverage
+from ...utils.save_load import init_model, ModelSaver
+
+
+class ClassTrainer(object):
+    def __init__(self, config, mode, model, eval_func):
+        self.config = config
+        self.mode = mode  # referenced by _build_ema_model
+        self.model = model
+        self.eval = eval_func
+        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
+                                                          0) - 1
+        self.epochs = self.config["Global"].get("epochs", 1)
+        self.print_batch_step = self.config['Global']['print_batch_step']
+        self.save_interval = self.config["Global"]["save_interval"]
+        self.output_dir = self.config['Global']['output_dir']
+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
+        # AMP training and evaluating
+        # self._init_amp()
+
+        # build dataloader
+        self.use_dali = self.config["Global"].get("use_dali", False)
+        self.dataloader_dict = build_dataloader(self.config, mode)
+
+        # build loss
+        self.train_loss_func, self.unlabel_train_loss_func = build_loss(
+            self.config, mode)
+
+        # build metric
+        self.train_metric_func = build_metrics(config, "train")
+
+        # build optimizer
+        self.optimizer, self.lr_sch = build_optimizer(
+            self.config, self.dataloader_dict["Train"].max_iter,
+            [self.model, self.train_loss_func], self.update_freq)
+
+        # build model saver
+        self.model_saver = ModelSaver(
+            self,
+            net_name="model",
+            loss_name="train_loss_func",
+            opt_name="optimizer",
+            model_ema_name="model_ema")
+
+        # build best metric
+        self.best_metric = {
+            "metric": -1.0,
+            "epoch": 0,
+        }
+
+        # key: metric name; val: AverageMeter recording that metric
+        self.output_info = dict()
+        self.time_info = {
+            "batch_cost": AverageMeter(
+                "batch_cost", '.5f', postfix=" s,"),
+            "reader_cost": AverageMeter(
+                "reader_cost", ".5f", postfix=" s,"),
+        }
+
+        # build EMA model
+        self.model_ema = self._build_ema_model()
+        self._init_checkpoints()
+
+        # for visualdl
+        self.vdl_writer = self._init_vdl()
+
+    def __call__(self):
+        # global iter counter
+        self.global_step = 0
+        for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
+            # for one epoch train
+            self.train_epoch(epoch_id)
+
+            metric_msg = ", ".join(
+                [self.output_info[key].avg_info for key in self.output_info])
+            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+                epoch_id, self.epochs, metric_msg))
+            self.output_info.clear()
+
+            acc = 0.0
+            if self.config["Global"][
+                    "eval_during_train"] and epoch_id % self.config["Global"][
+                        "eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
+                acc = self.eval(epoch_id)
+
+                # step lr (by epoch) according to given metric, such as acc
+                for i in range(len(self.lr_sch)):
+                    if getattr(self.lr_sch[i], "by_epoch", False) and \
+                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
+                        self.lr_sch[i].step(acc)
+
+                if acc > self.best_metric["metric"]:
+                    self.best_metric["metric"] = acc
+                    self.best_metric["epoch"] = epoch_id
+                    self.model_saver.save(
+                        self.best_metric,
+                        prefix="best_model",
+                        save_student_model=True)
+
+                logger.info("[Eval][Epoch {}][best metric: {}]".format(
+                    epoch_id, self.best_metric["metric"]))
+                logger.scaler(
+                    name="eval_acc",
+                    value=acc,
+                    step=epoch_id,
+                    writer=self.vdl_writer)
+
+                self.model.train()
+
+                if self.model_ema:
+                    ori_model, self.model = self.model, self.model_ema.module
+                    acc_ema = self.eval(epoch_id)
+                    self.model = ori_model
+                    self.model_ema.module.eval()
+
+                    if acc_ema > self.best_metric["metric_ema"]:
+                        self.best_metric["metric_ema"] = acc_ema
+                        self.model_saver.save(
+                            {
+                                "metric": acc_ema,
+                                "epoch": epoch_id
+                            },
+                            prefix="best_model_ema")
+                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                        epoch_id, self.best_metric["metric_ema"]))
+                    logger.scaler(
+                        name="eval_acc_ema",
+                        value=acc_ema,
+                        step=epoch_id,
+                        writer=self.vdl_writer)
+
+            # save model
+            if self.save_interval > 0 and epoch_id % self.save_interval == 0:
+                self.model_saver.save(
+                    {
+                        "metric": acc,
+                        "epoch": epoch_id
+                    },
+                    prefix=f"epoch_{epoch_id}")
+
+            # save the latest model
+            self.model_saver.save(
+                {
+                    "metric": acc,
+                    "epoch": epoch_id
+                }, prefix="latest")
+
+    def train_epoch(self, epoch_id):
+        tic = time.time()
+
+        for iter_id in range(self.dataloader_dict["Train"].max_iter):
+            batch = self.dataloader_dict["Train"].get_batch()
+
+            profiler.add_profiler_step(self.config["profiler_options"])
+            if iter_id == 5:
+                for key in self.time_info:
+                    self.time_info[key].reset()
+            self.time_info["reader_cost"].update(time.time() - tic)
+
+            batch_size = batch[0].shape[0]
+            if not self.config["Global"].get("use_multilabel", False):
+                batch[1] = batch[1].reshape([batch_size, -1])
+            self.global_step += 1
+
+            # forward & backward & step opt
+            # if engine.amp:
+            #     with paddle.amp.auto_cast(
+            #             custom_black_list={
+            #                 "flatten_contiguous_range", "greater_than"
+            #             },
+            #             level=engine.amp_level):
+            #         out = engine.model(batch)
+            #         loss_dict = engine.train_loss_func(out, batch[1])
+            #     loss = loss_dict["loss"] / engine.update_freq
+            #     scaled = engine.scaler.scale(loss)
+            #     scaled.backward()
+            #     if (iter_id + 1) % engine.update_freq == 0:
+            #         for i in range(len(engine.optimizer)):
+            #             engine.scaler.minimize(engine.optimizer[i], scaled)
+            # else:
+            #     out = engine.model(batch)
+            #     loss_dict = engine.train_loss_func(out, batch[1])
+            #     loss = loss_dict["loss"] / engine.update_freq
+            #     loss.backward()
+            #     if (iter_id + 1) % engine.update_freq == 0:
+            #         for i in range(len(engine.optimizer)):
+            #             engine.optimizer[i].step()
+            out = self.model(batch)
+            loss_dict = self.train_loss_func(out, batch[1])
+            loss = loss_dict["loss"] / self.update_freq
+            loss.backward()
+
+            # the two identical `(iter_id + 1) % self.update_freq` blocks are
+            # merged here; the order (step all, then clear all) is unchanged
+            if (iter_id + 1) % self.update_freq == 0:
+                for i in range(len(self.optimizer)):
+                    self.optimizer[i].step()
+                # clear grad
+                for i in range(len(self.optimizer)):
+                    self.optimizer[i].clear_grad()
+                # step lr(by step)
+                for i in range(len(self.lr_sch)):
+                    if not getattr(self.lr_sch[i], "by_epoch", False):
+                        self.lr_sch[i].step()
+                # update ema
+                if self.model_ema:
+                    self.model_ema.update(self.model)
+
+            # below code just for logging
+            # update metric_for_logger
+            update_metric(self, out, batch, batch_size)
+            # update_loss_for_logger
+            update_loss(self, loss_dict, batch_size)
+            self.time_info["batch_cost"].update(time.time() - tic)
+            if iter_id % self.print_batch_step == 0:
+                log_info(self, batch_size, epoch_id, iter_id)
+            tic = time.time()
+
+        # step lr(by epoch)
+        for i in range(len(self.lr_sch)):
+            if getattr(self.lr_sch[i], "by_epoch", False) and \
+                    type_name(self.lr_sch[i]) != "ReduceOnPlateau":
+                self.lr_sch[i].step()
+
+    def __del__(self):
+        # guard with getattr: __del__ may run even if __init__ failed early
+        if getattr(self, "vdl_writer", None) is not None:
+            self.vdl_writer.close()
+
+    def _init_vdl(self):
+        if self.config['Global']['use_visualdl'] and dist.get_rank() == 0:
+            vdl_writer_path = os.path.join(self.output_dir, "vdl")
+            if not os.path.exists(vdl_writer_path):
+                os.makedirs(vdl_writer_path)
+            return LogWriter(logdir=vdl_writer_path)
+        return None
+
+    def _build_ema_model(self):
+        if "EMA" in self.config and self.mode == "train":
+            model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+            self.best_metric["metric_ema"] = 0
+            return model_ema
+        else:
+            return None
+
+    def _init_checkpoints(self):
+        if self.config["Global"].get("checkpoints", None) is not None:
+            metric_info = init_model(self.config.Global, self.model,
+                                     self.optimizer, self.train_loss_func,
+                                     self.model_ema)
+            if metric_info is not None:
+                self.best_metric.update(metric_info)
diff --git a/ppcls/engine/train/regular_train_epoch.py b/ppcls/engine/train/regular_train_epoch.py
deleted file mode 100644
index f49e57e43da9eadb400e61d8ac9307e66afcf1ef..0000000000000000000000000000000000000000
--- a/ppcls/engine/train/regular_train_epoch.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import absolute_import, division, print_function
-
-import time
-import paddle
-from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
-from ppcls.utils import profiler
-
-
-def regular_train_epoch(engine, epoch_id, print_batch_step):
-    tic = time.time()
-
-    for iter_id in range(engine.dataloader_dict["Train"].max_iter):
-        batch = engine.dataloader_dict["Train"].get_batch()
-
-        profiler.add_profiler_step(engine.config["profiler_options"])
-        if iter_id == 5:
-            for key in engine.time_info:
-                engine.time_info[key].reset()
-        engine.time_info["reader_cost"].update(time.time() - tic)
-
-        batch_size = batch[0].shape[0]
-        if not engine.config["Global"].get("use_multilabel", False):
-            batch[1] = batch[1].reshape([batch_size, -1])
-        engine.global_step += 1
-
-        # forward & backward & step opt
-        if engine.amp:
-            with paddle.amp.auto_cast(
-                    custom_black_list={
-                        "flatten_contiguous_range", "greater_than"
-                    },
-                    level=engine.amp_level):
-                out = engine.model(batch)
-                loss_dict = engine.train_loss_func(out, batch[1])
-            loss = loss_dict["loss"] / engine.update_freq
-            scaled = engine.scaler.scale(loss)
-            scaled.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.scaler.minimize(engine.optimizer[i], scaled)
-        else:
-            out = engine.model(batch)
-            loss_dict = engine.train_loss_func(out, batch[1])
-            loss = loss_dict["loss"] / engine.update_freq
-            loss.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.optimizer[i].step()
-
-        if (iter_id + 1) % engine.update_freq == 0:
-            # clear grad
-            for i in range(len(engine.optimizer)):
-                engine.optimizer[i].clear_grad()
-            # step lr(by step)
-            for i in range(len(engine.lr_sch)):
-                if not getattr(engine.lr_sch[i], "by_epoch", False):
-                    engine.lr_sch[i].step()
-            # update ema
-            if engine.model_ema:
-                engine.model_ema.update(engine.model)
-
-        # below code just for logging
-        # update metric_for_logger
-        update_metric(engine, out, batch, batch_size)
-        # update_loss_for_logger
-        update_loss(engine, loss_dict, batch_size)
-        engine.time_info["batch_cost"].update(time.time() - tic)
-        if iter_id % print_batch_step == 0:
-            log_info(engine, batch_size, epoch_id, iter_id)
-        tic = time.time()
-
-    # step lr(by epoch)
-    for i in range(len(engine.lr_sch)):
-        if getattr(engine.lr_sch[i], "by_epoch", False) and \
-                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
-            engine.lr_sch[i].step()
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 2515be363ee5b268617c239428758e1e5a82ed04..776459fe566110b9b603d3e1a6ea04e5f8046d17 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -151,20 +151,20 @@
 class ModelSaver(object):
     def __init__(self,
-                 engine,
+                 trainer,
                  net_name="model",
                  loss_name="train_loss_func",
                  opt_name="optimizer",
                  model_ema_name="model_ema"):
         # net, loss, opt, model_ema, output_dir,
-        self.engine = engine
+        self.trainer = trainer
         self.net_name = net_name
         self.loss_name = loss_name
         self.opt_name = opt_name
         self.model_ema_name = model_ema_name
 
-        arch_name = engine.config["Arch"]["name"]
-        self.output_dir = os.path.join(engine.output_dir, arch_name)
+        arch_name = trainer.config["Arch"]["name"]
+        self.output_dir = os.path.join(trainer.output_dir, arch_name)
         _mkdir_if_not_exist(self.output_dir)
 
     def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +174,8 @@ class ModelSaver(object):
 
         save_dir = os.path.join(self.output_dir, prefix)
 
-        params_state_dict = getattr(self.engine, self.net_name).state_dict()
-        loss = getattr(self.engine, self.loss_name)
+        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
+        loss = getattr(self.trainer, self.loss_name)
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +190,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")
 
         paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.engine, self.model_ema_name)
+        model_ema = getattr(self.trainer, self.model_ema_name)
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")
-        optimizer = getattr(self.engine, self.opt_name)
+        optimizer = getattr(self.trainer, self.opt_name)
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")
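
For reference, a minimal sketch of how the refactored `Engine` is driven after this patch. `Engine.train` and `Engine.eval` are no longer methods but the callables returned by `build_train_func` / `build_eval_func` (a `ClassTrainer` and a `ClassEval` instance for the default classification task). The YAML path below is illustrative only; `get_config` is the existing loader in `ppcls.utils.config`.

```python
from ppcls.utils import config as cfg_util
from ppcls.engine.engine import Engine

# parse a classification config; any config file works the same way
cfg = cfg_util.get_config(
    "ppcls/configs/quick_start/ResNet50_vd.yaml", overrides=None, show=False)

engine = Engine(cfg, mode="train")
engine.train()        # ClassTrainer.__call__: runs the full training loop
metric = engine.eval()  # ClassEval.__call__: one eval pass, returns the metric
```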