engine.py 20.2 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

D
dongshuilong 已提交
17
import os
D
dongshuilong 已提交
18 19 20 21
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
D
dongshuilong 已提交
22
from paddle import nn
D
dongshuilong 已提交
23 24
import numpy as np
import random
D
dongshuilong 已提交
25 26 27 28 29 30 31

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
W
dbg  
weishengyu 已提交
32
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
D
dongshuilong 已提交
33 34 35 36 37 38 39 40 41 42 43
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
D
dongshuilong 已提交
44 45
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
D
dongshuilong 已提交
46 47 48
from ppcls.arch.gears.identity_head import IdentityHead


D
dongshuilong 已提交
49
class Engine(object):
    """Drive training, evaluation, inference and export from one config.

    The constructor builds everything the selected ``mode`` needs (data
    loaders, losses, metrics, model, optimizer, AMP and distributed wrappers);
    ``train``/``eval``/``infer``/``export`` then use those components.
    """

    def __init__(self, config, mode="train"):
        """Build the components required by ``mode``.

        Args:
            config: dict-like config with sections such as ``Global``,
                ``Arch``, ``DataLoader``, ``Loss``, ``Metric``, ``Optimizer``,
                ``AMP`` and ``Infer``. It is mutated in place (class_num,
                AMP level, distributed flag, train metrics).
            mode: one of "train", "eval", "infer", "export".

        Raises:
            AssertionError: on an unknown mode, eval mode or device.
        """
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.is_rec = bool("Head" in self.config["Arch"] or
                           self.config["Arch"].get("is_rec", False))

        self._set_seed()

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        if self.eval_mode not in ["classification", "retrieval"]:
            msg = "Invalid eval mode: {}".format(self.eval_mode)
            logger.error(msg)
            # explicit raise keeps the AssertionError type but survives `-O`
            raise AssertionError(msg)
        self.train_epoch_func = train_epoch
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl: only rank 0 writes logs, and only while training
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # AMP training and evaluating; an `AMP:` key with a null value is
        # treated as "AMP on, default settings".
        self.amp = "AMP" in self.config
        amp_config = (self.config["AMP"] or {}) if self.amp else {}
        self.scale_loss = amp_config.get("scale_loss", 1.0)
        self.use_dynamic_loss_scaling = amp_config.get(
            "use_dynamic_loss_scaling", False)
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        #TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})

        # eval components are needed in eval mode, or in train mode when
        # evaluating between epochs
        need_eval = self.mode == "eval" or (
            self.mode == "train" and
            self.config["Global"]["eval_during_train"])

        self._build_dataloaders(need_eval)
        self._build_losses(need_eval)
        self.train_metric_func = self._build_train_metric()
        self.eval_metric_func = self._build_eval_metric(need_eval)

        # build model
        self.model = build_model(self.config)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain (train_loss_func only exists in train mode)
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            if pretrained_model.startswith("http"):
                load_fn = load_dygraph_pretrain_from_url
            else:
                load_fn = load_dygraph_pretrain
            load_fn([self.model, getattr(self, 'train_loss_func', None)],
                    pretrained_model)

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                len(self.train_dataloader),
                [self.model, self.train_loss_func])

        if self.amp:
            self._init_amp()

        self._init_distributed()

        # build preprocess/postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def _set_seed(self):
        """Seed paddle, numpy and random when Global.seed is configured."""
        seed = self.config["Global"].get("seed", None)
        # NOTE: the previous default `False` compared equal to 0
        # (`False == 0`), so every run without a seed was silently seeded
        # with 0. `None`/`False` now mean "no seeding"; 0 is still accepted.
        if seed is None or seed is False:
            return
        assert isinstance(seed, int) and not isinstance(
            seed, bool), "The 'seed' must be a integer!"
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def _build_dataloaders(self, need_eval):
        """Create the train loader (train mode) and the eval loader(s)."""
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
        if not need_eval:
            return
        if self.eval_mode == "classification":
            self.eval_dataloader = build_dataloader(
                self.config["DataLoader"], "Eval", self.device, self.use_dali)
        elif self.eval_mode == "retrieval":
            eval_config = self.config["DataLoader"]["Eval"]
            self.gallery_query_dataloader = None
            if len(eval_config.keys()) == 1:
                # a single entry serves as both gallery and query
                key = list(eval_config.keys())[0]
                self.gallery_query_dataloader = build_dataloader(
                    eval_config, key, self.device, self.use_dali)
            else:
                self.gallery_dataloader = build_dataloader(
                    eval_config, "Gallery", self.device, self.use_dali)
                self.query_dataloader = build_dataloader(
                    eval_config, "Query", self.device, self.use_dali)

    def _build_losses(self, need_eval):
        """Create train_loss_func (train mode) and eval_loss_func (or None)."""
        if self.mode == "train":
            self.train_loss_func = build_loss(self.config["Loss"]["Train"])
        if need_eval:
            loss_config = self.config.get("Loss", None)
            eval_loss_config = (loss_config.get("Eval")
                                if loss_config is not None else None)
            self.eval_loss_func = (build_loss(eval_loss_config)
                                   if eval_loss_config is not None else None)

    def _build_train_metric(self):
        """Return the train metric function, or None when not configured."""
        if self.mode != 'train':
            return None
        metric_config = self.config.get("Metric")
        if metric_config is None:
            return None
        metric_config = metric_config.get("Train")
        if metric_config is None:
            return None
        if getattr(self.train_dataloader, "collate_fn", None) is not None:
            # Drop every TopkAcc entry in place (the config is mutated, as
            # before); previously the last metric was popped even when no
            # TopkAcc was present, and an empty list raised NameError.
            kept = [m for m in metric_config if "TopkAcc" not in m]
            if len(kept) != len(metric_config):
                msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
                logger.warning(msg)
                metric_config[:] = kept
        return build_metrics(metric_config)

    def _build_eval_metric(self, need_eval):
        """Return the eval metric function, or None when not configured."""
        if not need_eval:
            return None
        metric_config = self.config.get("Metric")
        if self.eval_mode == "classification":
            if metric_config is None:
                return None
            eval_config = metric_config.get("Eval")
            if eval_config is None:
                return None
            return build_metrics(eval_config)
        # retrieval: fall back to Recall@1/5 when no Metric section exists
        if metric_config is None:
            metric_config = [{"name": "Recallk", "topk": (1, 5)}]
        else:
            metric_config = metric_config["Eval"]
        return build_metrics(metric_config)

    def _init_amp(self):
        """Create the GradScaler and AMP-decorate model/optimizer/loss.

        Decoration needs the optimizer and train loss, which only exist in
        train mode; other modes previously crashed with AttributeError here
        and now leave the model undecorated.
        """
        self.scaler = paddle.amp.GradScaler(
            init_loss_scaling=self.scale_loss,
            use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
        amp_config = self.config["AMP"] or {}
        amp_level = amp_config.get("level", "O1")
        if amp_level not in ["O1", "O2"]:
            msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
            logger.warning(msg)
            # only reachable with a real AMP dict, so this writes the config
            amp_config["level"] = "O1"
            amp_level = "O1"
        if self.mode != "train":
            return
        self.model, self.optimizer = paddle.amp.decorate(
            models=self.model,
            optimizers=self.optimizer,
            level=amp_level,
            save_dtype='float32')
        if len(self.train_loss_func.parameters()) > 0:
            self.train_loss_func = paddle.amp.decorate(
                models=self.train_loss_func,
                level=amp_level,
                save_dtype='float32')

    def _init_distributed(self):
        """Record Global.distributed and wrap model/loss in DataParallel."""
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if world_size != 4 and self.mode == "train":
            msg = f"The training strategy in config files provided by PaddleClas is based on 4 gpus. But the number of gpus is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use config files in PaddleClas to train."
            logger.warning(msg)
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            # a loss without parameters has nothing to synchronise
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

    def _save_model(self, metric_info, prefix):
        """Save model, optimizer and train loss under output_dir/prefix."""
        save_load.save_model(
            self.model,
            self.optimizer,
            metric_info,
            self.output_dir,
            model_name=self.config["Arch"]["name"],
            prefix=prefix,
            loss=self.train_loss_func)

    def train(self):
        """Train from the (possibly resumed) epoch up to Global.epochs.

        Evaluates every ``eval_interval`` epochs when ``eval_during_train`` is
        set and saves ``best_model`` on improvement; saves ``epoch_{id}``
        every ``save_interval`` epochs and ``latest`` after every epoch.
        """
        assert self.mode == "train"
        global_config = self.config["Global"]
        print_batch_step = global_config['print_batch_step']
        save_interval = global_config["save_interval"]
        epochs = global_config["epochs"]
        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # filled by train_epoch_func; values expose `.avg`
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if global_config.get("checkpoints") is not None:
            metric_info = init_model(global_config, self.model,
                                     self.optimizer, self.train_loss_func)
            if metric_info is not None:
                best_metric.update(metric_info)

        # one iteration fewer on Windows; the reason is not visible here
        num_batches = len(self.train_dataloader)
        self.max_iter = (num_batches - 1
                         if platform.system() == "Windows" else num_batches)
        for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, self.output_info[key].avg)
                for key in self.output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, epochs, metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            if global_config["eval_during_train"] and epoch_id % global_config[
                    "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    self._save_model(best_metric, "best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                self._save_model({"metric": acc,
                                  "epoch": epoch_id},
                                 "epoch_{}".format(epoch_id))
            # save the latest model
            self._save_model({"metric": acc, "epoch": epoch_id}, "latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Run the configured eval function and return its result."""
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        """Classify Infer.infer_imgs in batches and print the results.

        Images are split across ranks: each rank handles every
        ``world_size``-th image starting at its own rank.
        """
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        last_idx = len(image_list) - 1
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            # flush on a full batch or on the final (possibly partial) one
            if len(batch_data) >= batch_size or idx == last_idx:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
                result = self.postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()

    def export(self):
        """Export the model for inference to Global.save_inference_dir."""
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get("use_multilabel", False)
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])

        model.eval()
        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")
        # only the batch dimension is dynamic
        input_spec = [
            paddle.static.InputSpec(
                shape=[None] + self.config["Global"]["image_shape"],
                dtype='float32')
        ]
        if model.quanter:
            model.quanter.save_quantized_model(
                model.base_model, save_path, input_spec=input_spec)
        else:
            model = paddle.jit.to_static(model, input_spec=input_spec)
            paddle.jit.save(model, save_path)
D
dongshuilong 已提交
437 438


W
dbg  
weishengyu 已提交
439
class ExportModel(TheseusLayer):
    """
    ExportModel: wrap a trained model so it can be exported for inference.

    It optionally selects one sub-model and/or one output key, and appends an
    output activation: sigmoid for multi-label, otherwise softmax unless
    ``infer_add_softmax`` is False.
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model

        # A DistillationModel wraps several networks; the config names the
        # one whose output is exported.
        if isinstance(self.base_model, DistillationModel):
            chosen_name = config["infer_model_name"]
        else:
            chosen_name = None
        self.infer_model_name = chosen_name

        self.infer_output_key = config.get("infer_output_key", None)
        wants_features = self.infer_output_key == "features"
        if wants_features and isinstance(self.base_model, RecModel):
            # swap the head for an identity so raw features are exported
            self.base_model.head = IdentityHead()

        if use_multilabel:
            activation = nn.Sigmoid()
        elif config.get("infer_add_softmax", True):
            activation = nn.Softmax(axis=-1)
        else:
            activation = None
        self.out_act = activation

    def eval(self):
        """Switch this layer and all of its sub-layers to inference mode."""
        self.training = False
        for sublayer in self.sublayers():
            sublayer.training = False
            sublayer.eval()

    def forward(self, x):
        """Run the wrapped model and post-process its output for export."""
        out = self.base_model(x)
        # keep only the first entry of a list output
        if isinstance(out, list):
            out = out[0]
        # narrow a dict output to the configured sub-model, then to the key
        if self.infer_model_name is not None:
            out = out[self.infer_model_name]
        if self.infer_output_key is not None:
            out = out[self.infer_output_key]
        if self.out_act is not None:
            if isinstance(out, dict):
                out = out["logits"]
            out = self.out_act(out)
        return out