engine.py 22.9 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

D
dongshuilong 已提交
17
import os
D
dongshuilong 已提交
18 19 20 21
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
D
dongshuilong 已提交
22
from paddle import nn
D
dongshuilong 已提交
23 24
import numpy as np
import random
D
dongshuilong 已提交
25 26 27 28 29 30 31

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
W
dbg  
weishengyu 已提交
32
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
D
dongshuilong 已提交
33 34 35 36 37 38 39 40 41 42 43
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
D
dongshuilong 已提交
44 45
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
D
dongshuilong 已提交
46 47 48
from ppcls.arch.gears.identity_head import IdentityHead


D
dongshuilong 已提交
49
class Engine(object):
    """Unified PaddleClas runner for training, evaluation, inference and export.

    The constructor builds every component required by ``mode`` from
    ``config``: dataloaders, losses, metrics, model, optimizer, AMP and
    distributed wrappers, and inference pre/post-processing.
    """

    def __init__(self, config, mode="train"):
        """Build all components required by ``mode``.

        Args:
            config: Parsed configuration. ``Global``, ``Arch`` and
                ``DataLoader`` are always read; ``Loss``, ``Metric``,
                ``Optimizer``, ``AMP`` and ``Infer`` are read depending on
                ``mode`` and the ``Global`` options.
            mode: One of "train", "eval", "infer" or "export".

        Raises:
            AssertionError: If ``mode``, ``Global.seed`` or ``Global.device``
                is invalid.
            ValueError: If ``Global.eval_mode`` is not supported.
        """
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.is_rec = bool("Head" in self.config["Arch"] or
                           self.config["Arch"].get("is_rec", False))

        # set seed. 0 is a valid seed, so test the "unset" values explicitly:
        # `False == 0` is True in Python, which made the previous
        # `seed or seed == 0` check seed every run with 0 by default.
        seed = self.config["Global"].get("seed", None)
        if seed is not None and seed is not False:
            assert isinstance(seed, int), "The 'seed' must be a integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func. Validate explicitly: an `assert`
        # vanishes under `python -O`, and `logger.error(...)` as the assert
        # message evaluated to None.
        if self.eval_mode not in ["classification", "retrieval", "adaface"]:
            msg = "Invalid eval mode: {}".format(self.eval_mode)
            logger.error(msg)
            raise ValueError(msg)
        self.train_epoch_func = train_epoch
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl: only rank 0 writes logs, and only while training
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        #TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})

        # Evaluation components are needed in eval mode, or in train mode
        # when evaluating during training (short-circuit keeps the
        # `eval_during_train` lookup limited to train mode).
        need_eval = self.mode == "eval" or (
            self.mode == "train" and self.config["Global"]["eval_during_train"])

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
        if need_eval:
            if self.eval_mode in ["classification", "adaface"]:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                eval_loader_config = self.config["DataLoader"]["Eval"]
                if len(eval_loader_config.keys()) == 1:
                    # a single combined gallery/query dataset
                    key = list(eval_loader_config.keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        eval_loader_config, key, self.device, self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        eval_loader_config, "Gallery", self.device,
                        self.use_dali)
                    self.query_dataloader = build_dataloader(
                        eval_loader_config, "Query", self.device,
                        self.use_dali)

        # build loss
        if self.mode == "train":
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if need_eval:
            loss_config = self.config.get("Loss", None)
            eval_loss_config = None
            if loss_config is not None:
                eval_loss_config = loss_config.get("Eval")
            if eval_loss_config is not None:
                self.eval_loss_func = build_loss(eval_loss_config)
            else:
                self.eval_loss_func = None

        # build metric
        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
                "Metric"] and self.config["Metric"]["Train"]:
            metric_config = self.config["Metric"]["Train"]
            if hasattr(self.train_dataloader, "collate_fn"
                       ) and self.train_dataloader.collate_fn is not None:
                # Filter instead of popping while iterating: popping during
                # enumerate() skipped the entry following each removed one.
                # The config list is still updated in place, as before.
                kept_metrics = []
                for m in metric_config:
                    if "TopkAcc" in m:
                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                        logger.warning(msg)
                    else:
                        kept_metrics.append(m)
                metric_config[:] = kept_metrics
            self.train_metric_func = build_metrics(metric_config)
        else:
            self.train_metric_func = None

        if need_eval:
            if self.eval_mode == "classification":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    self.eval_metric_func = build_metrics(self.config["Metric"]
                                                          ["Eval"])
                else:
                    self.eval_metric_func = None
            elif self.eval_mode == "retrieval":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    metric_config = self.config["Metric"]["Eval"]
                else:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
        else:
            self.eval_metric_func = None

        # build model
        self.model = build_model(self.config, self.mode)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain (the train loss may own parameters too)
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            load_targets = [self.model, getattr(self, 'train_loss_func', None)]
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(load_targets, pretrained_model)
            else:
                load_dygraph_pretrain(load_targets, pretrained_model)

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                len(self.train_dataloader),
                [self.model, self.train_loss_func])

        # AMP training and evaluating
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        if self.amp:
            self._init_amp()

        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def _init_amp(self):
        """Configure AMP: Paddle flags, grad scaler, level and decoration.

        Called from ``__init__`` only when an ``AMP`` section is present,
        after the model, losses and (in train mode) optimizer are built.
        Sets ``scale_loss``, ``use_dynamic_loss_scaling``, ``scaler``,
        ``amp_level`` and ``amp_eval`` and may rewrite ``AMP.level`` /
        ``AMP.use_fp16_test`` in the config.
        """
        amp_related_flags = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            amp_related_flags.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.set_flags(amp_related_flags)

        self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
        self.use_dynamic_loss_scaling = self.config["AMP"].get(
            "use_dynamic_loss_scaling", False)
        self.scaler = paddle.amp.GradScaler(
            init_loss_scaling=self.scale_loss,
            use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

        self.amp_level = self.config['AMP'].get("level", "O1")
        if self.amp_level not in ["O1", "O2"]:
            msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
            logger.warning(msg)
            self.config['AMP']["level"] = "O1"
            self.amp_level = "O1"

        self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
        # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
        if self.mode == "train" and self.config["Global"].get(
                "eval_during_train",
                True) and self.amp_level == "O2" and not self.amp_eval:
            msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
            logger.warning(msg)
            self.config["AMP"]["use_fp16_test"] = True
            self.amp_eval = True

        # TODO(gaotingquan): to compatible with different versions of Paddle
        # Compare numeric (major, minor). Slicing the version string
        # (`__version__[:3]`) misread e.g. "2.4" and "2.10" as older than
        # 2.3. Develop builds report "0.0.x"; unparsable versions are
        # treated the same way.
        try:
            paddle_major_minor = tuple(
                int(part) for part in paddle.__version__.split(".")[:2])
        except ValueError:
            paddle_major_minor = (0, 0)
        use_new_amp_api = (paddle_major_minor == (0, 0) or
                           paddle_major_minor >= (2, 3))

        if not use_new_amp_api:
            # paddle version < 2.3.0 and not develop
            if self.mode == "train":
                self.model, self.optimizer = paddle.amp.decorate(
                    models=self.model,
                    optimizers=self.optimizer,
                    level=self.amp_level,
                    save_dtype='float32')
            elif self.amp_eval:
                if self.amp_level == "O2":
                    msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
                    logger.warning(msg)
                    self.amp_eval = False
                else:
                    # Without `optimizers`, decorate() returns only the
                    # model, so it must not be unpacked into two names.
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')
        else:
            # paddle version >= 2.3.0 or develop
            if self.mode == "train" or self.amp_eval:
                self.model = paddle.amp.decorate(
                    models=self.model,
                    level=self.amp_level,
                    save_dtype='float32')

        # a loss that owns parameters must be decorated as well
        if self.mode == "train" and len(self.train_loss_func.parameters(
        )) > 0:
            self.train_loss_func = paddle.amp.decorate(
                models=self.train_loss_func,
                level=self.amp_level,
                save_dtype='float32')

    def train(self):
        """Run the complete training loop.

        Optionally resumes from ``Global.checkpoints``, then trains from the
        epoch after the restored one through ``Global.epochs``. When
        ``Global.eval_during_train`` is set, the model is evaluated every
        ``Global.eval_interval`` epochs once ``Global.start_eval_epoch`` is
        reached, and ``best_model`` is saved whenever the metric improves.
        ``epoch_<id>`` is saved every ``Global.save_interval`` epochs and
        ``latest`` after every epoch.
        """
        assert self.mode == "train"
        global_cfg = self.config['Global']
        print_batch_step = global_cfg['print_batch_step']
        save_interval = global_cfg["save_interval"]
        best_metric = {"metric": -1.0, "epoch": 0}
        # running meters, keyed by name, filled by train_epoch_func and
        # cleared after each epoch's summary is logged
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter("batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter("reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            restored_info = init_model(self.config.Global, self.model,
                                       self.optimizer, self.train_loss_func)
            if restored_info is not None:
                best_metric.update(restored_info)

        # on Windows one iteration fewer than the dataloader length is used
        num_batches = len(self.train_dataloader)
        if platform.system() == "Windows":
            self.max_iter = num_batches - 1
        else:
            self.max_iter = num_batches

        total_epochs = global_cfg["epochs"]
        model_name = self.config["Arch"]["name"]
        for epoch_id in range(best_metric["epoch"] + 1, total_epochs + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)
            if self.use_dali:
                self.train_dataloader.reset()

            metric_msg = ", ".join(
                meter.avg_info for meter in self.output_info.values())
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, total_epochs, metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            start_eval_epoch = global_cfg.get("start_eval_epoch", 0) - 1
            should_eval = (global_cfg["eval_during_train"] and
                           epoch_id % global_cfg["eval_interval"] == 0 and
                           epoch_id > start_eval_epoch)
            if should_eval:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=model_name,
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)
                self.model.train()

            # periodic snapshot
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer,
                    {"metric": acc, "epoch": epoch_id},
                    self.output_dir,
                    model_name=model_name,
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # always refresh the latest snapshot
            save_load.save_model(
                self.model,
                self.optimizer,
                {"metric": acc, "epoch": epoch_id},
                self.output_dir,
                model_name=model_name,
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate with the ``*_eval`` function selected by ``eval_mode``.

        The model is put into eval mode for the call and switched back to
        train mode before returning.

        Args:
            epoch_id: Epoch number forwarded to the eval function.

        Returns:
            The value returned by the selected eval function; ``train``
            compares it against the best metric so far.
        """
        assert self.mode in ["train", "eval"]
        self.model.eval()
        metric_value = self.eval_func(self, epoch_id)
        self.model.train()
        return metric_value

    @paddle.no_grad()
    def infer(self):
        """Run batched inference over ``Infer.infer_imgs`` and print results.

        The image list is sharded across ranks (rank ``r`` takes every
        ``world_size``-th file starting at index ``r``). Each file is read as
        bytes, passed through ``Infer.transforms``, and batched up to
        ``Infer.batch_size``. Every batch goes through the model, optionally
        under AMP autocast. The post-processed result is printed.
        """
        assert self.mode == "infer" and self.eval_mode == "classification"
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        image_files = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_files = image_files[rank::world_size]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        pending_inputs = []
        pending_files = []
        last_idx = len(image_files) - 1
        for idx, image_file in enumerate(image_files):
            with open(image_file, 'rb') as f:
                sample = f.read()
            for op in self.preprocess_func:
                sample = op(sample)
            pending_inputs.append(sample)
            pending_files.append(image_file)

            # keep accumulating until the batch is full or input runs out
            if len(pending_inputs) < batch_size and idx != last_idx:
                continue

            batch_tensor = paddle.to_tensor(pending_inputs)
            if self.amp and self.amp_eval:
                with paddle.amp.auto_cast(
                        custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                        },
                        level=self.amp_level):
                    out = self.model(batch_tensor)
            else:
                out = self.model(batch_tensor)

            # unwrap list / dict model outputs down to the prediction tensor
            if isinstance(out, list):
                out = out[0]
            if isinstance(out, dict) and "logits" in out:
                out = out["logits"]
            if isinstance(out, dict) and "output" in out:
                out = out["output"]

            result = self.postprocess_func(out, pending_files)
            print(result)
            pending_inputs.clear()
            pending_files.clear()

    def export(self):
        """Export the model as a static-graph inference model.

        Wraps ``self.model`` in ``ExportModel`` and, if
        ``Global.pretrained_model`` is set, loads those weights into it. It
        then calls ``rep()`` on every sublayer that defines it and converts
        the model to static graph with a float32 input spec of
        ``[None] + Global.image_shape``. The result is saved to
        ``<Global.save_inference_dir>/inference``; when the model carries a
        quanter, the quanter saves it with an ``_int8`` suffix instead.
        """
        assert self.mode == "export"
        # Multilabel export is enabled only when requested AND the first
        # Eval metric is ATTRMetric. A missing or empty Metric.Eval section
        # now yields False instead of raising KeyError/TypeError.
        use_multilabel = False
        if self.config["Global"].get("use_multilabel", False):
            metric_section = self.config.get("Metric", None) or {}
            eval_metrics = metric_section.get("Eval", None) or []
            use_multilabel = len(
                eval_metrics) > 0 and "ATTRMetric" in eval_metrics[0]
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])

        model.eval()

        # for rep nets: fold re-parameterizable branches before export
        for layer in self.model.sublayers():
            if hasattr(layer, "rep"):
                layer.rep()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")

        # list() so a tuple image_shape works as well as a list
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + list(self.config["Global"]["image_shape"]),
                    dtype='float32')
            ])
        if hasattr(model.base_model,
                   "quanter") and model.base_model.quanter is not None:
            model.base_model.quanter.save_quantized_model(model,
                                                          save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)
        logger.info(
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )
D
dongshuilong 已提交
489 490


W
dbg  
weishengyu 已提交
491
class ExportModel(TheseusLayer):
    """Export wrapper that selects the final output and applies an activation.

    For a ``DistillationModel`` the sub-model named by
    ``config["infer_model_name"]`` is selected. ``infer_output_key``, when
    set, picks one entry of the output; for a ``RecModel`` the value
    "features" replaces the head with ``IdentityHead``. The output activation
    is Sigmoid for multilabel models, Softmax over the last axis unless
    ``infer_add_softmax`` is False, or none at all.
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model
        # we should choose a final model to export
        self.infer_model_name = (config["infer_model_name"]
                                 if isinstance(model, DistillationModel) else
                                 None)

        self.infer_output_key = config.get("infer_output_key", None)
        if self.infer_output_key == "features" and isinstance(model,
                                                              RecModel):
            self.base_model.head = IdentityHead()

        if use_multilabel:
            self.out_act = nn.Sigmoid()
        elif config.get("infer_add_softmax", True):
            self.out_act = nn.Softmax(axis=-1)
        else:
            self.out_act = None

    def eval(self):
        """Switch this wrapper and every sublayer to inference mode."""
        self.training = False
        for sublayer in self.sublayers():
            sublayer.training = False
            sublayer.eval()

    def forward(self, x):
        out = self.base_model(x)
        if isinstance(out, list):
            out = out[0]
        if self.infer_model_name is not None:
            out = out[self.infer_model_name]
        if self.infer_output_key is not None:
            out = out[self.infer_output_key]
        if self.out_act is not None:
            # dict outputs carry the classifier scores under "logits"
            if isinstance(out, dict):
                out = out["logits"]
            out = self.out_act(out)
        return out