# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
from paddle import nn
import numpy as np
import random

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead


class Engine(object):
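    """PaddleClas engine: builds the model, dataloaders, loss, metrics and
    optimizer from a config dict and drives the train/eval/infer/export
    workflows (summary of what this class does, added for readability)."""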
    def __init__(self, config, mode="train"):
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
                                                                    False):
            self.is_rec = True
        else:
            self.is_rec = False

        # set seed
        seed = self.config["Global"].get("seed", False)
        if seed or seed == 0:
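            # `seed == 0` is falsy, so it is checked explicitly to keep 0 a valid seed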
            assert isinstance(seed, int), "The 'seed' must be an integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        assert self.eval_mode in [
            "classification", "retrieval", "adaface"
        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
        self.train_epoch_func = train_epoch
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        # TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})
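        # class_num is forwarded to the dataloader config so batch transforms
        # that need the label-space size (e.g. one-hot targets for Mixup) can read it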
        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode == "classification":
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                if len(self.config["DataLoader"]["Eval"].keys()) == 1:
                    key = list(self.config["DataLoader"]["Eval"].keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], key, self.device,
                        self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Gallery",
                        self.device, self.use_dali)
                    self.query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Query",
                        self.device, self.use_dali)

        # build loss
        if self.mode == "train":
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
                else:
                    self.eval_loss_func = None
            else:
                self.eval_loss_func = None

        # build metric
        if self.mode == 'train':
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    if hasattr(
                            self.train_dataloader, "collate_fn"
                    ) and self.train_dataloader.collate_fn is not None:
                        for m_idx, m in enumerate(metric_config):
                            if "TopkAcc" in m:
                                msg = "The 'TopkAcc' metric cannot be used when 'batch_transform_ops' is set in the config, so it has been removed."
                                logger.warning(msg)
                                metric_config.pop(m_idx)
                                break
                    self.train_metric_func = build_metrics(metric_config)
                else:
                    self.train_metric_func = None
        else:
            self.train_metric_func = None

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            metric_config = self.config.get("Metric")
            if self.eval_mode == "classification":
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)
            elif self.eval_mode == "retrieval":
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
        else:
            self.eval_metric_func = None

        # build model
        self.model = build_model(self.config)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain
        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                len(self.train_dataloader),
                [self.model, self.train_loss_func])

        # AMP training and evaluating
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        # for amp
        if self.amp:
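            # framework flags commonly set for AMP performance; the cuDNN
            # persistent batch-norm flag only takes effect on CUDA builds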
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)

            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
            self.scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

            self.amp_level = self.config['AMP'].get("level", "O1")
            if self.amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimization level of AMP only supports 'O1' and 'O2'. The level has been set to 'O1'."
                logger.warning(msg)
                self.config['AMP']["level"] = "O1"
                self.amp_level = "O1"

            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
            if self.config["Global"].get(
                    "eval_during_train",
                    True) and self.amp_level == "O2" and not self.amp_eval:
                msg = "PaddlePaddle currently only supports FP16 evaluation when training with AMP O2."
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True

            # TODO(gaotingquan): to be compatible with Paddle 2.2, 2.3, develop and so on.
            paddle_version = sum([
                int(x) * 10**(2 - i)
                for i, x in enumerate(paddle.__version__.split(".")[:3])
            ])
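            # encodes the version as a 3-digit int, e.g. "2.3.0" -> 230; the
            # develop build reports "0.0.0", which collapses to 0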
            # paddle version < 2.3.0 and not develop
            if paddle_version < 230 and paddle_version != 0:
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle >= 2.3.0. FP32 evaluation will be used instead; note that the Eval dataset's output_fp16 should be 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        self.model, self.optimizer = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            # paddle version >= 2.3.0 or develop
            else:
                self.model = paddle.amp.decorate(
                    models=self.model,
                    level=self.amp_level,
                    save_dtype='float32')

            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=self.amp_level,
                    save_dtype='float32')

        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpus is {world_size} in the current training. Please modify the strategy (learning rate, batch size and so on) if you use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)
        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def train(self):
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: metric/loss name; val: the AverageMeter tracking that value
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func)
            if metric_info is not None:
                best_metric.update(metric_info)

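        # drop the dataloader's last batch on Windows, as a workaround for
        # issues with the final (possibly partial) batch on that platform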
        self.max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)

        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, self.output_info[key].avg)
                for key in self.output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model",
                        loss=self.train_loss_func)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                model_name=self.config["Arch"]["name"],
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
                result = self.postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()

    def export(self):
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get("use_multilabel", False)
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])

        model.eval()
        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")
        if model.quanter:
            model.quanter.save_quantized_model(
                model.base_model,
                save_path,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32')
                ])
        else:
            model = paddle.jit.to_static(
                model,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32')
                ])
            paddle.jit.save(model, save_path)


class ExportModel(TheseusLayer):
    """
    ExportModel: append the model's final activation (softmax, or sigmoid for multilabel) for export
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model
        # we should choose a final model to export
        if isinstance(self.base_model, DistillationModel):
            self.infer_model_name = config["infer_model_name"]
        else:
            self.infer_model_name = None

        self.infer_output_key = config.get("infer_output_key", None)
        if self.infer_output_key == "features" and isinstance(self.base_model,
                                                              RecModel):
            self.base_model.head = IdentityHead()
        if use_multilabel:
            self.out_act = nn.Sigmoid()
        else:
            if config.get("infer_add_softmax", True):
                self.out_act = nn.Softmax(axis=-1)
            else:
                self.out_act = None

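    # explicitly force inference mode on this layer and all sublayers,
    # mirroring what nn.Layer.eval() is expected to do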
    def eval(self):
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        x = self.base_model(x)
        if isinstance(x, list):
            x = x[0]
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
        if self.out_act is not None:
            if isinstance(x, dict):
                x = x["logits"]
            x = self.out_act(x)
        return x
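

# Typical usage (a minimal sketch mirroring tools/train.py; the config path is
# illustrative and helper signatures may differ across PaddleClas versions):
#
#     from ppcls.utils import config as config_util
#     cfg = config_util.get_config(
#         "ppcls/configs/ImageNet/ResNet/ResNet50.yaml", show=False)
#     engine = Engine(cfg, mode="train")
#     engine.train()
#     # or: Engine(cfg, "eval").eval(), Engine(cfg, "export").export()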