engine.py 26.4 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

D
dongshuilong 已提交
17
import os
D
dongshuilong 已提交
18 19 20 21
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
D
dongshuilong 已提交
22
from paddle import nn
D
dongshuilong 已提交
23 24
import numpy as np
import random
D
dongshuilong 已提交
25 26 27 28 29 30

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
W
dbg  
weishengyu 已提交
31
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
D
dongshuilong 已提交
32 33 34 35
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
F
flytocc 已提交
36
from ppcls.utils.ema import ExponentialMovingAverage
D
dongshuilong 已提交
37 38 39 40 41 42 43
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
D
dongshuilong 已提交
44
from ppcls.engine import train as train_method
45
from ppcls.engine.train.utils import type_name
D
dongshuilong 已提交
46
from ppcls.engine import evaluation
D
dongshuilong 已提交
47 48 49
from ppcls.arch.gears.identity_head import IdentityHead


D
dongshuilong 已提交
50
class Engine(object):
    """Build all components of a PaddleClas run and drive it.

    Depending on ``mode`` the constructor builds the dataloaders, losses,
    metrics, model, optimizer and the optional AMP / EMA / distributed
    wrappers described by ``config``. :meth:`train`, :meth:`eval`,
    :meth:`infer` and :meth:`export` then run the matching workflow.
    """

    @staticmethod
    def _is_paddle_older_than_2_3(version):
        """Return True if the Paddle ``version`` string is a release < 2.3.

        Develop builds report ``0.0.x`` and are treated as new. Version
        strings that cannot be parsed into ``major.minor`` are also treated
        as new.
        """
        parts = str(version).split(".")
        try:
            major, minor = int(parts[0]), int(parts[1])
        except (IndexError, ValueError):
            return False
        if (major, minor) == (0, 0):
            return False
        return (major, minor) < (2, 3)

    def __init__(self, config, mode="train"):
        """Initialise every component needed for ``mode``.

        Args:
            config: Parsed PaddleClas config. It is dict-like; :meth:`train`
                additionally uses attribute access
                (``config.Global.checkpoints``). It is mutated in place in
                a few places (``Arch.class_num``, ``DataLoader.class_num``,
                ``AMP.level``, ``AMP.use_fp16_test``, ``Metric.Train``,
                ``Global.distributed``).
            mode: One of ``"train"``, ``"eval"``, ``"infer"``, ``"export"``.
        """
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_mode = self.config["Global"].get("train_mode", None)
        self.is_rec = bool("Head" in self.config["Arch"] or
                           self.config["Arch"].get("is_rec", False))

        # set seed
        # ``None``/``False`` mean "no seed configured". The check must not
        # rely on ``seed == 0``: ``False == 0`` is True in Python, which
        # previously made every run without a seed get seeded with 0.
        seed = self.config["Global"].get("seed", None)
        if seed is not None and seed is not False:
            assert isinstance(seed, int), "The 'seed' must be a integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        # Explicit check instead of ``assert ..., logger.error(...)``: the
        # assert carried a ``None`` message and was skipped under ``-O``.
        if self.eval_mode not in ["classification", "retrieval", "adaface"]:
            msg = "Invalid eval mode: {}".format(self.eval_mode)
            logger.error(msg)
            raise AssertionError(msg)
        if self.train_mode is None:
            self.train_epoch_func = train_method.train_epoch
        else:
            self.train_epoch_func = getattr(train_method,
                                            "train_epoch_" + self.train_mode)
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl: only rank 0 writes logs, and only while training
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # Global.class_num is deprecated: copy it to Arch.class_num when the
        # latter is not set, otherwise ignore it.
        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        # TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                self.unlabel_train_dataloader = build_dataloader(
                    self.config["DataLoader"], "UnLabelTrain", self.device,
                    self.use_dali)
            else:
                self.unlabel_train_dataloader = None

            self.iter_per_epoch = len(self.train_dataloader)
            if platform.system() == "Windows":
                # One iteration per epoch is dropped on Windows (behaviour
                # kept as-is; the reason is not visible in this file).
                self.iter_per_epoch -= 1
            if self.config["Global"].get("iter_per_epoch", None):
                # set max iteration per epoch manually, when training by
                # iteration(s), such as XBM, FixMatch.
                self.iter_per_epoch = self.config["Global"].get(
                    "iter_per_epoch")
            # Round down to a multiple of update_freq (gradient accumulation).
            self.iter_per_epoch = (self.iter_per_epoch // self.update_freq *
                                   self.update_freq)

        # Evaluation components are needed in eval mode, or when training
        # with periodic evaluation.
        eval_enabled = self.mode == "eval" or (
            self.mode == "train" and
            self.config["Global"]["eval_during_train"])

        if eval_enabled:
            if self.eval_mode in ["classification", "adaface"]:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                eval_loader_config = self.config["DataLoader"]["Eval"]
                if len(eval_loader_config.keys()) == 1:
                    # a single dataset serves as both gallery and query
                    key = list(eval_loader_config.keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        eval_loader_config, key, self.device, self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        eval_loader_config, "Gallery", self.device,
                        self.use_dali)
                    self.query_dataloader = build_dataloader(
                        eval_loader_config, "Query", self.device,
                        self.use_dali)

        # build loss
        if self.mode == "train":
            label_loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(label_loss_info)
            unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
                                                                       None)
            self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
        if eval_enabled:
            loss_config = self.config.get("Loss", None)
            eval_loss_config = (loss_config.get("Eval")
                                if loss_config is not None else None)
            self.eval_loss_func = (build_loss(eval_loss_config)
                                   if eval_loss_config is not None else None)

        # build metric
        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
                "Metric"] and self.config["Metric"]["Train"]:
            metric_config = self.config["Metric"]["Train"]
            if hasattr(self.train_dataloader, "collate_fn"
                       ) and self.train_dataloader.collate_fn is not None:
                # Batch-level transforms mix labels, so top-k accuracy cannot
                # be computed. Filter into a new list rather than popping
                # while iterating, which skipped the element after each
                # removal.
                kept_metrics = []
                for m in metric_config:
                    if "TopkAcc" in m:
                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                        logger.warning(msg)
                    else:
                        kept_metrics.append(m)
                # update in place so self.config reflects the removal
                metric_config[:] = kept_metrics
            self.train_metric_func = build_metrics(metric_config)
        else:
            self.train_metric_func = None

        if eval_enabled:
            has_eval_metric = ("Metric" in self.config and
                               "Eval" in self.config["Metric"])
            if self.eval_mode == "classification":
                if has_eval_metric:
                    self.eval_metric_func = build_metrics(self.config["Metric"]
                                                          ["Eval"])
                else:
                    self.eval_metric_func = None
            elif self.eval_mode == "retrieval":
                if has_eval_metric:
                    metric_config = self.config["Metric"]["Eval"]
                else:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
            else:
                # "adaface": no metric object is built here; define the
                # attribute so it is never missing.
                self.eval_metric_func = None
        else:
            self.eval_metric_func = None

        # build model
        self.model = build_model(self.config, self.mode)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain (into the model and, when training, the loss module)
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            load_targets = [self.model, getattr(self, 'train_loss_func', None)]
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(load_targets, pretrained_model)
            else:
                load_dygraph_pretrain(load_targets, pretrained_model)

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                self.iter_per_epoch // self.update_freq,
                [self.model, self.train_loss_func])

        # AMP training and evaluating
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
            self.scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

            self.amp_level = self.config['AMP'].get("level", "O1")
            if self.amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                logger.warning(msg)
                self.config['AMP']["level"] = "O1"
                self.amp_level = "O1"

            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
            if self.mode == "train" and self.config["Global"].get(
                    "eval_during_train",
                    True) and self.amp_level == "O2" and not self.amp_eval:
                msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True

            # TODO(gaotingquan): to compatible with different versions of Paddle
            # Compare (major, minor) numerically: slicing the version string
            # sent Paddle 2.4+ down the legacy branch.
            if self._is_paddle_older_than_2_3(paddle.__version__):
                # paddle version < 2.3.0 and not develop
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        # Without ``optimizers``, decorate() returns only the
                        # model, so it must not be unpacked into two names.
                        self.model = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            else:
                # paddle version >= 2.3.0 or develop
                if self.mode == "train" or self.amp_eval:
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')

            # the loss module only needs decorating if it has parameters
            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=self.amp_level,
                    save_dtype='float32')

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))

        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def train(self):
        """Run the full training loop.

        Optionally resumes from ``Global.checkpoints``. Each epoch it
        evaluates (when enabled) and saves the best, periodic and latest
        checkpoints under ``output_dir``.
        """
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        epochs = self.config["Global"]["epochs"]
        best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }
        ema_module = None
        if self.ema:
            best_metric_ema = 0.0
            ema_module = self.model_ema.module
        # key: metric name, val: its running-average meter (filled by
        # train_epoch_func and reset after each epoch)
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func,
                                     ema_module)
            if metric_info is not None:
                best_metric.update(metric_info)

        # resume from the epoch after the one restored by the checkpoint
        for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, epochs, metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                         0) - 1
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0 and epoch_id > start_eval_epoch:
                acc = self.eval(epoch_id)

                # step lr (by epoch) according to given metric, such as acc
                for lr_sch in self.lr_sch:
                    if getattr(lr_sch, "by_epoch", False) and \
                            type_name(lr_sch) == "ReduceOnPlateau":
                        lr_sch.step(acc)

                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        ema=ema_module,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

                if self.ema:
                    # Temporarily swap in the EMA weights for evaluation and
                    # always restore the training model, even if eval fails.
                    ori_model, self.model = self.model, ema_module
                    try:
                        acc_ema = self.eval(epoch_id)
                    finally:
                        self.model = ori_model
                    # eval() switched the EMA module back to train mode
                    ema_module.eval()

                    if acc_ema > best_metric_ema:
                        best_metric_ema = acc_ema
                        save_load.save_model(
                            self.model,
                            self.optimizer,
                            {"metric": acc_ema,
                             "epoch": epoch_id},
                            self.output_dir,
                            ema=ema_module,
                            model_name=self.config["Arch"]["name"],
                            prefix="best_model_ema",
                            loss=self.train_loss_func)
                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
                        epoch_id, best_metric_ema))
                    logger.scaler(
                        name="eval_acc_ema",
                        value=acc_ema,
                        step=epoch_id,
                        writer=self.vdl_writer)

            # save model
            if save_interval > 0 and epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    ema=ema_module,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                ema=ema_module,
                model_name=self.config["Arch"]["name"],
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate ``self.model`` with the configured eval function.

        The model is put in eval mode for the call and left in train mode
        afterwards. Returns whatever ``self.eval_func`` returns (used as the
        metric in :meth:`train`).
        """
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        """Run classification inference on ``Infer.infer_imgs`` and print results.

        Images are sharded across ranks and processed in batches of
        ``Infer.batch_size``.
        """
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split: each rank takes every world_size-th image
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            # flush when the batch is full or at the last image
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)

                if self.amp and self.amp_eval:
                    with paddle.amp.auto_cast(
                            custom_black_list={
                                "flatten_contiguous_range", "greater_than"
                            },
                            level=self.amp_level):
                        out = self.model(batch_tensor)
                else:
                    out = self.model(batch_tensor)

                # unwrap the various output containers models may return
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
                result = self.postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()

    def export(self):
        """Export the model as a static inference model.

        The output goes to ``Global.save_inference_dir``; a quantized model
        is saved when a ``quanter`` is attached.
        """
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get("use_multilabel", False)
        if not use_multilabel:
            # An attribute metric as the first Eval metric also implies
            # sigmoid output. Tolerate configs without Metric.Eval instead of
            # raising KeyError/IndexError.
            eval_metrics = (self.config.get("Metric") or {}).get("Eval") or []
            use_multilabel = bool(eval_metrics) and \
                "ATTRMetric" in eval_metrics[0]
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])

        model.eval()

        # for rep nets: re-parameterize layers that have not been done yet
        for layer in self.model.sublayers():
            if hasattr(layer, "rep") and not getattr(layer, "is_repped",
                                                     False):
                layer.rep()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")

        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + self.config["Global"]["image_shape"],
                    dtype='float32')
            ])
        if hasattr(model.base_model,
                   "quanter") and model.base_model.quanter is not None:
            model.base_model.quanter.save_quantized_model(model,
                                                          save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)
        logger.info(
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )
D
dongshuilong 已提交
562 563


W
dbg  
weishengyu 已提交
564
class ExportModel(TheseusLayer):
    """
    ExportModel: wrap a trained model for inference export.

    Selects the sub-output to export (distillation sub-model and/or output
    key) and appends an output activation: sigmoid for multi-label models,
    otherwise softmax unless ``infer_add_softmax`` is False.
    """

    def __init__(self, config, model, use_multilabel):
        """
        Args:
            config: The ``Arch`` section of the config (dict-like).
            model: The trained model to wrap.
            use_multilabel: If True, use a sigmoid output activation.
        """
        super().__init__()
        self.base_model = model
        # we should choose a final model to export
        if isinstance(self.base_model, DistillationModel):
            self.infer_model_name = config["infer_model_name"]
        else:
            self.infer_model_name = None

        self.infer_output_key = config.get("infer_output_key", None)
        # exporting raw features of a rec model: replace its head with a no-op
        if self.infer_output_key == "features" and isinstance(self.base_model,
                                                              RecModel):
            self.base_model.head = IdentityHead()
        if use_multilabel:
            self.out_act = nn.Sigmoid()
        else:
            if config.get("infer_add_softmax", True):
                self.out_act = nn.Softmax(axis=-1)
            else:
                self.out_act = None

    def eval(self):
        """Switch this layer and every sublayer to inference mode."""
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        """Run the wrapped model and post-process its output for export."""
        x = self.base_model(x)
        if isinstance(x, list):
            x = x[0]
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
        if self.out_act is not None:
            # dict outputs carry the classification scores under "logits"
            if isinstance(x, dict):
                x = x["logits"]
            x = self.out_act(x)
        return x