engine.py 27.1 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

D
dongshuilong 已提交
17
import os
18
import platform
D
dongshuilong 已提交
19 20 21
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
D
dongshuilong 已提交
22
from paddle import nn
D
dongshuilong 已提交
23 24
import numpy as np
import random
D
dongshuilong 已提交
25

26
from ppcls.utils.misc import AverageMeter
D
dongshuilong 已提交
27 28 29
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
30
from ppcls.data import build_dataloader
W
dbg  
weishengyu 已提交
31
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
32
from ppcls.arch import apply_to_static
33 34 35
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
36
from ppcls.utils.amp import AutoCast, build_scaler
37
from ppcls.utils.ema import ExponentialMovingAverage
D
dongshuilong 已提交
38
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
39
from ppcls.utils.save_load import init_model
40
from ppcls.utils import save_load
D
dongshuilong 已提交
41 42 43 44

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
45
from ppcls.engine import train as train_method
46
from ppcls.engine.train.utils import type_name
D
dongshuilong 已提交
47
from ppcls.engine import evaluation
D
dongshuilong 已提交
48 49 50
from ppcls.arch.gears.identity_head import IdentityHead


D
dongshuilong 已提交
51
class Engine(object):
    """Drive training, evaluation, inference and export for a PaddleClas config.

    ``mode`` selects which resources ``__init__`` builds:

    * ``"train"``  - train dataloader(s), train losses, optimizer/LR scheduler,
      optional EMA and, when ``Global.eval_during_train`` is set, the eval
      dataloader(s), eval loss and eval metrics as well.
    * ``"eval"``   - eval dataloader(s), eval loss and eval metrics.
    * ``"infer"``  - the ``Infer.transforms`` preprocessing operators and the
      ``Infer.PostProcess`` postprocessing.
    * ``"export"`` - mainly the model (plus pretrained weights / AMP setup).

    Args:
        config: The parsed config. ``Global``, ``Arch`` and ``DataLoader`` are
            always read; other sections depend on ``mode``. It is presumably an
            attribute-accessible dict (``train`` uses ``config.Global``). The
            config is mutated in place (e.g. ``DataLoader.class_num``,
            ``DataLoader.epochs``, ``Global.distributed``, ``Global.seed``).
        mode: One of ``"train"``, ``"eval"``, ``"infer"``, ``"export"``.
    """

    def __init__(self, config, mode="train"):
        """Build every component required by ``mode`` from ``config``."""
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_mode = self.config["Global"].get("train_mode", None)
        # A model counts as "rec" when Arch has a Head section or sets is_rec.
        self.is_rec = ("Head" in self.config["Arch"] or
                       bool(self.config["Arch"].get("is_rec", False)))

        # set seed
        # A missing key, None and False all mean "no seed". (With a default of
        # False, `seed == 0` held for unset configs, so they were silently
        # seeded with 0 and the distributed fallback below never triggered.)
        seed = self.config["Global"].get("seed", None)
        if seed is False:
            seed = None
        if seed is not None:
            assert isinstance(seed, int), "The 'seed' must be a integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger: <output_dir>/<Arch.name>/<mode>.log
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        assert self.eval_mode in [
            "classification", "retrieval", "adaface"
        ], "Invalid eval mode: {}".format(self.eval_mode)
        if self.train_mode is None:
            self.train_epoch_func = train_method.train_epoch
        else:
            # Global.train_mode "xxx" dispatches to train_method.train_epoch_xxx
            self.train_epoch_func = getattr(train_method,
                                            "train_epoch_" + self.train_mode)
        # Global.eval_mode "xxx" dispatches to evaluation.xxx_eval
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl: only rank 0 writes logs, and only while training
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # gradient accumulation: the optimizer is built with
        # iter_per_epoch // update_freq steps per epoch (see below).
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # Global.class_num is deprecated: copy it to Arch.class_num only when
        # the latter is not set, otherwise ignore it.
        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        #TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})
        self.config["DataLoader"].update({
            "epochs": self.config["Global"]["epochs"]
        })

        # Whether evaluation resources (dataloader, loss, metric) are needed.
        need_eval = self.mode == "eval" or (
            self.mode == "train" and
            self.config["Global"]["eval_during_train"])

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                self.unlabel_train_dataloader = build_dataloader(
                    self.config["DataLoader"], "UnLabelTrain", self.device,
                    self.use_dali)
            else:
                self.unlabel_train_dataloader = None

            # On Windows one iteration is subtracted from the loader length
            # (reason not visible here; presumably a platform dataloader quirk).
            self.iter_per_epoch = len(self.train_dataloader)
            if platform.system() == "Windows":
                self.iter_per_epoch -= 1
            if self.config["Global"].get("iter_per_epoch", None):
                # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
                self.iter_per_epoch = self.config["Global"].get(
                    "iter_per_epoch")
            if self.iter_per_epoch < self.update_freq:
                logger.warning(
                    "The arg Global.update_freq greater than iter_per_epoch and has been set to 1. This may be caused by too few of batches."
                )
                self.update_freq = 1
            # Round down to a whole number of accumulation cycles.
            self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq

        if need_eval:
            if self.eval_mode in ["classification", "adaface"]:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                # A single Eval sub-loader serves as both gallery and query;
                # otherwise separate "Gallery" and "Query" loaders are built.
                self.gallery_query_dataloader = None
                if len(self.config["DataLoader"]["Eval"].keys()) == 1:
                    key = list(self.config["DataLoader"]["Eval"].keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], key, self.device,
                        self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Gallery",
                        self.device, self.use_dali)
                    self.query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Query",
                        self.device, self.use_dali)

        # build loss
        if self.mode == "train":
            label_loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(label_loss_info)
            unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
                                                                       None)
            self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
        if need_eval:
            loss_config = self.config.get("Loss", None)
            eval_loss_config = (loss_config.get("Eval")
                                if loss_config is not None else None)
            self.eval_loss_func = (build_loss(eval_loss_config)
                                   if eval_loss_config is not None else None)

        # build metric
        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
                "Metric"] and self.config["Metric"]["Train"]:
            metric_config = self.config["Metric"]["Train"]
            if getattr(self.train_dataloader, "collate_fn", None) is not None:
                # Per the warning below, accuracy cannot be computed when batch
                # transforms (a collate_fn) are used, so TopkAcc is dropped.
                kept_metrics = []
                for m in metric_config:
                    if "TopkAcc" in m:
                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                        logger.warning(msg)
                    else:
                        kept_metrics.append(m)
                # Filter in place (the list belongs to the config) rather than
                # popping while iterating, which skipped the following entry.
                metric_config[:] = kept_metrics
            self.train_metric_func = build_metrics(metric_config)
        else:
            self.train_metric_func = None

        if need_eval:
            if self.eval_mode in ["classification", "adaface"]:
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    self.eval_metric_func = build_metrics(self.config["Metric"]
                                                          ["Eval"])
                else:
                    self.eval_metric_func = None
            elif self.eval_mode == "retrieval":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    metric_config = self.config["Metric"]["Eval"]
                else:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
        else:
            self.eval_metric_func = None

        # build model
        self.model = build_model(self.config, self.mode)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain: the train loss (None outside training) is passed with
        # the model, presumably so losses with parameters are restored too.
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            pretrain_targets = [
                self.model, getattr(self, 'train_loss_func', None)
            ]
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(pretrain_targets,
                                               pretrained_model)
            else:
                load_dygraph_pretrain(pretrain_targets, pretrained_model)

        # build optimizer: one scheduler step per parameter update
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                self.iter_per_epoch // self.update_freq,
                [self.model, self.train_loss_func])

        # amp (needs the model, and in train mode the optimizer and loss)
        self._init_amp()

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))

        # check the gpu num: the reference strategy assumes 8 GPUs for AdamW
        # configs and 4 otherwise; warn when the actual count differs.
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            # Only a loss that owns parameters needs DataParallel wrapping.
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

            # set different seed in different GPU manually in distributed environment
            if seed is None:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
            )
            rank_seed = int(seed) + dist.get_rank()
            paddle.seed(rank_seed)
            np.random.seed(rank_seed)
            random.seed(rank_seed)

        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def train(self):
        """Run the training loop from epoch 1 (or the resumed epoch + 1).

        For every epoch this calls ``self.train_epoch_func`` and logs the
        averaged entries of ``self.output_info``. When evaluation is due it
        evaluates, steps epoch-based ``ReduceOnPlateau`` schedulers with the
        score, and saves ``best_model`` (and ``best_model_ema``) on
        improvement. It saves ``epoch_{n}`` every ``Global.save_interval``
        epochs and always overwrites ``latest``. The VisualDL writer, if any,
        is closed at the end.
        """
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }
        ema_module = None
        if self.ema:
            # Same sentinel as best_metric so a first EMA score of 0.0 is kept.
            best_metric_ema = -1.0
            ema_module = self.model_ema.module
        # Filled by train_epoch_func; each value exposes `.avg_info`.
        self.output_info = {}
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            # Resume model/optimizer/loss/EMA state; the returned info (if any)
            # supplies the epoch to continue from.
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func,
                                     ema_module)
            if metric_info is not None:
                best_metric.update(metric_info)
            if hasattr(self.train_dataloader.batch_sampler, "set_epoch"):
                self.train_dataloader.batch_sampler.set_epoch(best_metric[
                    "epoch"])

        epochs = self.config["Global"]["epochs"]
        model_name = self.config["Arch"]["name"]
        # Evaluate only from epoch Global.start_eval_epoch onwards.
        start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                     0) - 1

        for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, epochs, metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0 and epoch_id > start_eval_epoch:
                acc = self.eval(epoch_id)

                # step lr (by epoch) according to given metric, such as acc
                for scheduler in self.lr_sch:
                    if getattr(scheduler, "by_epoch", False) and \
                            type_name(scheduler) == "ReduceOnPlateau":
                        scheduler.step(acc)

                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        ema=ema_module,
                        model_name=model_name,
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

                if self.ema:
                    # Temporarily swap in the EMA module so self.eval() scores it.
                    ori_model, self.model = self.model, ema_module
                    acc_ema = self.eval(epoch_id)
                    self.model = ori_model
                    # self.eval() leaves the evaluated module in train mode.
                    ema_module.eval()

                    if acc_ema > best_metric_ema:
                        best_metric_ema = acc_ema
                        save_load.save_model(
                            self.model,
                            self.optimizer,
                            {"metric": acc_ema,
                             "epoch": epoch_id},
                            self.output_dir,
                            ema=ema_module,
                            model_name=model_name,
                            prefix="best_model_ema",
                            loss=self.train_loss_func)
                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
                        epoch_id, best_metric_ema))
                    logger.scaler(
                        name="eval_acc_ema",
                        value=acc_ema,
                        step=epoch_id,
                        writer=self.vdl_writer)

            # save model
            if save_interval > 0 and epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    ema=ema_module,
                    model_name=model_name,
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                ema=ema_module,
                model_name=model_name,
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate ``self.model`` with the function chosen by ``eval_mode``.

        The model is switched to eval mode for the call and back to train mode
        afterwards. Returns whatever ``self.eval_func`` returns; ``train``
        compares it with the best score, so presumably a scalar metric.
        """
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        """Predict ``Infer.infer_imgs`` in batches (classification only).

        Images are split across trainers as ``image_list[rank::world_size]``,
        read as raw bytes, passed through ``self.preprocess_func`` and batched
        in groups of ``Infer.batch_size``. Each batch output is unwrapped and
        handed to ``self.postprocess_func`` along with the file names.

        Returns:
            list: The post-processed results for this trainer's images.
        """
        assert self.mode == "infer" and self.eval_mode == "classification"
        results = []
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        last_idx = len(image_list) - 1
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            # Flush when the batch is full or the last image was added.
            if len(batch_data) >= batch_size or idx == last_idx:
                batch_tensor = paddle.to_tensor(batch_data)

                with self.auto_cast(is_eval=True):
                    out = self.model(batch_tensor)

                # Unwrap list/dict outputs (e.g. a distillation "Student"
                # branch, "logits" or "output" keys) down to a single result.
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]

                result = self.postprocess_func(out, image_file_list)
                logger.info(result)
                results.extend(result)
                batch_data.clear()
                image_file_list.clear()
        return results

    def export(self):
        """Convert the model to a static graph and save it for inference.

        The model is wrapped in ``ExportModel`` and optionally loaded from
        ``Global.pretrained_model``. Layers exposing ``re_parameterize`` are
        re-parameterized, and the result is traced with an input spec of
        ``[None] + Global.image_shape`` (float32). It is saved to
        ``<Global.save_inference_dir>/inference``, or via the base model's
        ``quanter`` with an ``_int8`` suffix when one is present.
        """
        assert self.mode == "export"
        # Multi-label export (Sigmoid head) is enabled by Global.use_multilabel
        # or an ATTRMetric first eval metric; configs without Metric.Eval are
        # tolerated.
        eval_metric_config = (self.config.get("Metric") or {}).get("Eval") or []
        use_multilabel = self.config["Global"].get(
            "use_multilabel", False) or (len(eval_metric_config) > 0 and
                                         "ATTRMetric" in eval_metric_config[0])
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(model.base_model,
                                               pretrained_model)
            else:
                load_dygraph_pretrain(model.base_model, pretrained_model)

        model.eval()

        # for re-parameterization nets; a missing is_repped flag is treated
        # as "not yet re-parameterized".
        for layer in self.model.sublayers():
            if hasattr(layer, "re_parameterize") and not getattr(
                    layer, "is_repped", False):
                layer.re_parameterize()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")

        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + self.config["Global"]["image_shape"],
                    dtype='float32')
            ])
        if hasattr(model.base_model,
                   "quanter") and model.base_model.quanter is not None:
            model.base_model.quanter.save_quantized_model(model,
                                                          save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)
        logger.info(
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )

    def _init_amp(self):
        """Set up automatic mixed precision from the optional ``AMP`` section.

        Always sets ``self.auto_cast`` and ``self.scaler``. When ``AMP`` is
        absent or ``AMP.use_amp`` is false, both are built with
        ``use_amp=False`` and nothing else happens. Otherwise the method does
        the following:

        * sets AMP-related Paddle flags;
        * validates ``AMP.level``, falling back to ``"O1"``;
        * forces FP16 evaluation when training with O2 and evaluating;
        * decorates the model/optimizer (train) or the model alone (FP16
          eval) and the train loss if it has parameters.

        Must run after the model and, in train mode, the optimizer and the
        train loss have been built.
        """
        amp_config = self.config.get("AMP", None)
        use_amp = bool(amp_config and amp_config.get("use_amp", True))

        if not use_amp:
            self.auto_cast = AutoCast(use_amp)
            self.scaler = build_scaler(use_amp)
            return

        amp_flags = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            amp_flags.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.set_flags(amp_flags)

        use_promote = amp_config.get("use_promote", False)
        amp_level = amp_config.get("level", "O1")
        if amp_level not in ["O1", "O2"]:
            msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
            logger.warning(msg)
            amp_level = amp_config["level"] = "O1"

        amp_eval = amp_config.get("use_fp16_test", False)
        # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
        if self.mode == "train" and self.config["Global"].get(
                "eval_during_train",
                True) and amp_level == "O2" and not amp_eval:
            msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
            logger.warning(msg)
            amp_config["use_fp16_test"] = True
            amp_eval = True

        self.auto_cast = AutoCast(
            use_amp,
            amp_level=amp_level,
            use_promote=use_promote,
            amp_eval=amp_eval)

        scale_loss = amp_config.get("scale_loss", 1.0)
        use_dynamic_loss_scaling = amp_config.get(
            "use_dynamic_loss_scaling", False)
        self.scaler = build_scaler(
            use_amp,
            scale_loss=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)

        # Use the locals here: `self.amp_eval` / `self.amp_level` are never
        # assigned and used to raise AttributeError.
        if self.mode == "train":
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=amp_level,
                save_dtype='float32')
        elif amp_eval:
            self.model = paddle.amp.decorate(
                models=self.model, level=amp_level, save_dtype='float32')

        if self.mode == "train" and len(self.train_loss_func.parameters(
        )) > 0:
            self.train_loss_func = paddle.amp.decorate(
                models=self.train_loss_func,
                level=amp_level,
                save_dtype='float32')

D
dongshuilong 已提交
590

W
dbg  
weishengyu 已提交
591
class ExportModel(TheseusLayer):
    """Wrap a trained model into the network that gets exported for inference.

    The wrapper does the following:

    * picks one sub-network's output for distillation models
      (``infer_model_name``);
    * optionally selects a key from the output (``infer_output_key``);
    * appends an output activation: Sigmoid for multi-label models,
      otherwise Softmax unless ``infer_add_softmax`` is false.

    Args:
        config: The ``Arch`` section of the config.
        model: The model to export.
        use_multilabel: Append Sigmoid instead of Softmax when true.
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model

        # A distillation model bundles several networks; Arch.infer_model_name
        # selects whose output is exported.
        is_distillation = isinstance(self.base_model, DistillationModel)
        self.infer_model_name = (config["infer_model_name"]
                                 if is_distillation else None)

        self.infer_output_key = config.get("infer_output_key", None)
        # Exporting raw features of a recognition model: replace its head.
        exports_features = self.infer_output_key == "features"
        if exports_features and isinstance(self.base_model, RecModel):
            self.base_model.head = IdentityHead()

        if use_multilabel:
            activation = nn.Sigmoid()
        elif config.get("infer_add_softmax", True):
            activation = nn.Softmax(axis=-1)
        else:
            activation = None
        self.out_act = activation

    def eval(self):
        """Put this wrapper and every sublayer into inference mode."""
        self.training = False
        for sublayer in self.sublayers():
            sublayer.training = False
            sublayer.eval()

    def forward(self, x):
        out = self.base_model(x)
        if isinstance(out, list):
            out = out[0]
        if self.infer_model_name is not None:
            out = out[self.infer_model_name]
        if self.infer_output_key is not None:
            out = out[self.infer_output_key]
        if self.out_act is not None:
            # The activation is applied to the logits when output is a dict.
            if isinstance(out, dict):
                out = out["logits"]
            out = self.out_act(out)
        return out