提交 15f6f581 编写于 作者: D dongshuilong

refactor trainer v2

上级 ebde0e13
......@@ -47,19 +47,18 @@ from ppcls.utils import save_load
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine.train import classification_train, retrieval_train
from ppcls.engine.eval import classification_eval, retrieval_eval
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
class Core(object):
class Engine(object):
def __init__(self, config, mode="train"):
assert mode in ['train', 'eval', 'infer', 'export']
assert mode in ["train", "eval", "infer", "export"]
self.mode = mode
self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
......@@ -68,14 +67,10 @@ class Core(object):
print_config(config)
# init train_func and eval_func
if self.eval_mode == "classification":
self.evaler = classification_eval
self.trainer = classification_train
elif self.eval_mode == "retrieval":
self.trainer = retrieval_train
self.evaler = retrieval_eval
else:
logger.warning("Invalid eval mode: {}".format(self.eval_mode))
assert self.eval_mode in ["classification", "retrieval"], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.train_epoch_func = train_epoch
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
self.use_dali = self.config['Global'].get("use_dali", False)
# for visualdl
......@@ -242,7 +237,7 @@ class Core(object):
self.config["Global"]["epochs"] + 1):
acc = 0.0
# for one epoch train
self.trainer(self, epoch_id, print_batch_step)
self.train_epoch_func(self, epoch_id, print_batch_step)
if self.use_dali:
self.train_dataloader.reset()
......@@ -304,7 +299,7 @@ class Core(object):
def eval(self, epoch_id=0):
assert self.mode in ["train", "eval"]
self.model.eval()
eval_result = self.evaler(self, epoch_id)
eval_result = self.eval_func(self, epoch_id)
self.model.train()
return eval_result
......
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.eval.classification import classification_eval
from ppcls.engine.eval.retrieval import retrieval_eval
from ppcls.engine.evaluation.classification import classification_eval
from ppcls.engine.evaluation.retrieval import retrieval_eval
......@@ -11,5 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.train.classification import classification_train
from ppcls.engine.train.retrieval import retrieval_train
from ppcls.engine.train.train import train_epoch
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import datetime
import os
import platform
import sys
import time
import numpy as np
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter
from ppcls.engine.train.utils import update_loss, update_metric, log_info
def classification_train(trainer, epoch_id, print_batch_step):
tic = time.time()
train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader(
)
for iter_id, batch in enumerate(train_dataloader):
if iter_id >= trainer.max_iter:
break
if iter_id == 5:
for key in trainer.time_info:
trainer.time_info[key].reset()
trainer.time_info["reader_cost"].update(time.time() - tic)
if trainer.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch_size = batch[0].shape[0]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
trainer.global_step += 1
# image input
if trainer.amp:
with paddle.amp.auto_cast(custom_black_list={
"flatten_contiguous_range", "greater_than"
}):
out = trainer.model(batch[0])
loss_dict = trainer.train_loss_func(out, batch[1])
else:
out = trainer.model(batch[0])
# calc loss
if trainer.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None):
loss_dict = trainer.train_loss_func(out, batch[1:])
else:
loss_dict = trainer.train_loss_func(out, batch[1])
# step opt and lr
if trainer.amp:
scaled = trainer.scaler.scale(loss_dict["loss"])
scaled.backward()
trainer.scaler.minimize(trainer.optimizer, scaled)
else:
loss_dict["loss"].backward()
trainer.optimizer.step()
trainer.optimizer.clear_grad()
trainer.lr_sch.step()
# below code just for logging
# update metric_for_logger
update_metric(trainer, out, batch, batch_size)
# update_loss_for_logger
update_loss(trainer, loss_dict, batch_size)
trainer.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
log_info(trainer, batch_size, epoch_id, iter_id)
tic = time.time()
......@@ -29,7 +29,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.engine.train.utils import update_loss, update_metric, log_info
def retrieval_train(trainer, epoch_id, print_batch_step):
def train_epoch(trainer, epoch_id, print_batch_step):
tic = time.time()
train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader(
......@@ -55,10 +55,10 @@ def retrieval_train(trainer, epoch_id, print_batch_step):
with paddle.amp.auto_cast(custom_black_list={
"flatten_contiguous_range", "greater_than"
}):
out = trainer.model(batch[0], batch[1])
out = forward(trainer, batch)
loss_dict = trainer.train_loss_func(out, batch[1])
else:
out = trainer.model(batch[0], batch[1])
out = forward(trainer, batch)
# calc loss
if trainer.config["DataLoader"]["Train"]["dataset"].get(
......@@ -81,10 +81,15 @@ def retrieval_train(trainer, epoch_id, print_batch_step):
# below code just for logging
# update metric_for_logger
update_metric(trainer, out, batch, batch_size)
# update_loss_for_logger
update_loss(trainer, loss_dict, batch_size)
trainer.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
log_info(trainer, batch_size, epoch_id, iter_id)
tic = time.time()
def forward(trainer, batch):
if trainer.eval_mode == "classification":
return trainer.model(batch[0])
else:
return trainer.model(batch[0], batch[1])
......@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.core import Core
from ppcls.engine.engine import Engine
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(
args.config, overrides=args.override, show=False)
evaler = Core(config, mode="eval")
evaler.eval()
engine = Engine(config, mode="eval")
engine.eval()
......@@ -24,11 +24,11 @@ import paddle
import paddle.nn as nn
from ppcls.utils import config
from ppcls.engine.core import Core
from ppcls.engine.engine import Engine
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(
args.config, overrides=args.override, show=False)
exporter = Core(config, mode="export")
exporter.export()
engine = Engine(config, mode="export")
engine.export()
......@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.core import Core
from ppcls.engine.engine import Engine
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(
args.config, overrides=args.override, show=False)
inferer = Core(config, mode="infer")
inferer.infer()
engine = Engine(config, mode="infer")
engine.infer()
......@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.core import Core
from ppcls.engine.engine import Engine
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(
args.config, overrides=args.override, show=False)
trainer = Core(config, mode="train")
trainer.train()
engine = Engine(config, mode="train")
engine.train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册