From 15f6f5813914b02644aab39f791dd9cc1940994c Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Tue, 24 Aug 2021 03:02:55 +0000 Subject: [PATCH] refactor trainer v2 --- ppcls/engine/{core.py => engine.py} | 25 +++--- ppcls/engine/{eval => evaluation}/__init__.py | 4 +- .../{eval => evaluation}/classification.py | 0 .../engine/{eval => evaluation}/retrieval.py | 0 ppcls/engine/train/__init__.py | 3 +- ppcls/engine/train/classification.py | 89 ------------------- ppcls/engine/train/{retrieval.py => train.py} | 13 ++- tools/eval.py | 6 +- tools/export_model.py | 6 +- tools/infer.py | 6 +- tools/train.py | 6 +- 11 files changed, 34 insertions(+), 124 deletions(-) rename ppcls/engine/{core.py => engine.py} (95%) rename ppcls/engine/{eval => evaluation}/__init__.py (82%) rename ppcls/engine/{eval => evaluation}/classification.py (100%) rename ppcls/engine/{eval => evaluation}/retrieval.py (100%) delete mode 100644 ppcls/engine/train/classification.py rename ppcls/engine/train/{retrieval.py => train.py} (90%) diff --git a/ppcls/engine/core.py b/ppcls/engine/engine.py similarity index 95% rename from ppcls/engine/core.py rename to ppcls/engine/engine.py index d0e448d7..711ec3b7 100644 --- a/ppcls/engine/core.py +++ b/ppcls/engine/engine.py @@ -47,19 +47,18 @@ from ppcls.utils import save_load from ppcls.data.utils.get_image_list import get_image_list from ppcls.data.postprocess import build_postprocess from ppcls.data import create_operators -from ppcls.engine.train import classification_train, retrieval_train -from ppcls.engine.eval import classification_eval, retrieval_eval +from ppcls.engine.train import train_epoch +from ppcls.engine import evaluation from ppcls.arch.gears.identity_head import IdentityHead -class Core(object): +class Engine(object): def __init__(self, config, mode="train"): - assert mode in ['train', 'eval', 'infer', 'export'] + assert mode in ["train", "eval", "infer", "export"] self.mode = mode self.config = config self.eval_mode = self.config["Global"].get("eval_mode", "classification") - # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], @@ -68,14 +67,10 @@ class Core(object): print_config(config) # init train_func and eval_func - if self.eval_mode == "classification": - self.evaler = classification_eval - self.trainer = classification_train - elif self.eval_mode == "retrieval": - self.trainer = retrieval_train - self.evaler = retrieval_eval - else: - logger.warning("Invalid eval mode: {}".format(self.eval_mode)) + assert self.eval_mode in ["classification", "retrieval"], logger.error("Invalid eval mode: {}".format(self.eval_mode)) + self.train_epoch_func = train_epoch + self.eval_func = getattr(evaluation, self.eval_mode + "_eval") + self.use_dali = self.config['Global'].get("use_dali", False) # for visualdl @@ -242,7 +237,7 @@ class Core(object): self.config["Global"]["epochs"] + 1): acc = 0.0 # for one epoch train - self.trainer(self, epoch_id, print_batch_step) + self.train_epoch_func(self, epoch_id, print_batch_step) if self.use_dali: self.train_dataloader.reset() @@ -304,7 +299,7 @@ class Core(object): def eval(self, epoch_id=0): assert self.mode in ["train", "eval"] self.model.eval() - eval_result = self.evaler(self, epoch_id) + eval_result = self.eval_func(self, epoch_id) self.model.train() return eval_result diff --git a/ppcls/engine/eval/__init__.py b/ppcls/engine/evaluation/__init__.py similarity index 82% rename from ppcls/engine/eval/__init__.py rename to ppcls/engine/evaluation/__init__.py index 0cc0dc98..e0cd7788 100644 --- a/ppcls/engine/eval/__init__.py +++ b/ppcls/engine/evaluation/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ppcls.engine.eval.classification import classification_eval -from ppcls.engine.eval.retrieval import retrieval_eval +from ppcls.engine.evaluation.classification import classification_eval +from ppcls.engine.evaluation.retrieval import retrieval_eval diff --git a/ppcls/engine/eval/classification.py b/ppcls/engine/evaluation/classification.py similarity index 100% rename from ppcls/engine/eval/classification.py rename to ppcls/engine/evaluation/classification.py diff --git a/ppcls/engine/eval/retrieval.py b/ppcls/engine/evaluation/retrieval.py similarity index 100% rename from ppcls/engine/eval/retrieval.py rename to ppcls/engine/evaluation/retrieval.py diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index 897c5846..800d3a41 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -11,5 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ppcls.engine.train.classification import classification_train -from ppcls.engine.train.retrieval import retrieval_train +from ppcls.engine.train.train import train_epoch diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py deleted file mode 100644 index c3d74f96..00000000 --- a/ppcls/engine/train/classification.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import absolute_import, division, print_function - -import datetime -import os -import platform -import sys -import time - -import numpy as np -import paddle - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../'))) -from ppcls.utils import logger -from ppcls.utils.misc import AverageMeter -from ppcls.engine.train.utils import update_loss, update_metric, log_info - - -def classification_train(trainer, epoch_id, print_batch_step): - tic = time.time() - - train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader( - ) - for iter_id, batch in enumerate(train_dataloader): - if iter_id >= trainer.max_iter: - break - if iter_id == 5: - for key in trainer.time_info: - trainer.time_info[key].reset() - trainer.time_info["reader_cost"].update(time.time() - tic) - if trainer.use_dali: - batch = [ - paddle.to_tensor(batch[0]['data']), - paddle.to_tensor(batch[0]['label']) - ] - batch_size = batch[0].shape[0] - batch[1] = batch[1].reshape([-1, 1]).astype("int64") - - trainer.global_step += 1 - # image input - if trainer.amp: - with paddle.amp.auto_cast(custom_black_list={ - "flatten_contiguous_range", "greater_than" - }): - out = trainer.model(batch[0]) - loss_dict = trainer.train_loss_func(out, batch[1]) - else: - out = trainer.model(batch[0]) - - # calc loss - if trainer.config["DataLoader"]["Train"]["dataset"].get( - "batch_transform_ops", None): - loss_dict = trainer.train_loss_func(out, batch[1:]) - else: - loss_dict = trainer.train_loss_func(out, batch[1]) - - # step opt and lr - if trainer.amp: - scaled = trainer.scaler.scale(loss_dict["loss"]) - scaled.backward() - trainer.scaler.minimize(trainer.optimizer, scaled) - else: - loss_dict["loss"].backward() - trainer.optimizer.step() - trainer.optimizer.clear_grad() - trainer.lr_sch.step() - - # below code just for logging - # update metric_for_logger - update_metric(trainer, out, batch, batch_size) - # update_loss_for_logger - update_loss(trainer, loss_dict, batch_size) - trainer.time_info["batch_cost"].update(time.time() - tic) - if iter_id % print_batch_step == 0: - log_info(trainer, batch_size, epoch_id, iter_id) - tic = time.time() diff --git a/ppcls/engine/train/retrieval.py b/ppcls/engine/train/train.py similarity index 90% rename from ppcls/engine/train/retrieval.py rename to ppcls/engine/train/train.py index d4a99459..9ea173a9 100644 --- a/ppcls/engine/train/retrieval.py +++ b/ppcls/engine/train/train.py @@ -29,7 +29,7 @@ from ppcls.utils.misc import AverageMeter from ppcls.engine.train.utils import update_loss, update_metric, log_info -def retrieval_train(trainer, epoch_id, print_batch_step): +def train_epoch(trainer, epoch_id, print_batch_step): tic = time.time() train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader( @@ -55,10 +55,10 @@ def retrieval_train(trainer, epoch_id, print_batch_step): with paddle.amp.auto_cast(custom_black_list={ "flatten_contiguous_range", "greater_than" }): - out = trainer.model(batch[0], batch[1]) + out = forward(trainer, batch) loss_dict = trainer.train_loss_func(out, batch[1]) else: - out = trainer.model(batch[0], batch[1]) + out = forward(trainer, batch) # calc loss if trainer.config["DataLoader"]["Train"]["dataset"].get( @@ -81,10 +81,15 @@ def retrieval_train(trainer, epoch_id, print_batch_step): # below code just for logging # update metric_for_logger update_metric(trainer, out, batch, batch_size) - # update_loss_for_logger update_loss(trainer, loss_dict, batch_size) trainer.time_info["batch_cost"].update(time.time() - tic) if iter_id % print_batch_step == 0: log_info(trainer, batch_size, epoch_id, iter_id) tic = time.time() + +def forward(trainer, batch): + if trainer.eval_mode == "classification": + return trainer.model(batch[0]) + else: + return trainer.model(batch[0], batch[1]) diff --git a/tools/eval.py b/tools/eval.py index 367f5f8b..e086da1b 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) from ppcls.utils import config -from ppcls.engine.core import Core +from ppcls.engine.engine import Engine if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) - evaler = Core(config, mode="eval") - evaler.eval() + engine = Engine(config, mode="eval") + engine.eval() diff --git a/tools/export_model.py b/tools/export_model.py index 7d324a61..01aba06c 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -24,11 +24,11 @@ import paddle import paddle.nn as nn from ppcls.utils import config -from ppcls.engine.core import Core +from ppcls.engine.engine import Engine if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) - exporter = Core(config, mode="export") - exporter.export() + engine = Engine(config, mode="export") + engine.export() diff --git a/tools/infer.py b/tools/infer.py index 9c00542c..4f6bf927 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) from ppcls.utils import config -from ppcls.engine.core import Core +from ppcls.engine.engine import Engine if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) - inferer = Core(config, mode="infer") - inferer.infer() + engine = Engine(config, mode="infer") + engine.infer() diff --git a/tools/train.py b/tools/train.py index 23a064e7..1d835903 100644 --- a/tools/train.py +++ b/tools/train.py @@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) from ppcls.utils import config -from ppcls.engine.core import Core +from ppcls.engine.engine import Engine if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) - trainer = Core(config, mode="train") - trainer.train() + engine = Engine(config, mode="train") + engine.train() -- GitLab