engine.py 21.6 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

D
dongshuilong 已提交
17
import os
D
dongshuilong 已提交
18 19 20
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
D
dongshuilong 已提交
21
from paddle import nn
D
dongshuilong 已提交
22 23
import numpy as np
import random
D
dongshuilong 已提交
24 25 26 27 28 29

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
W
dbg  
weishengyu 已提交
30
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
D
dongshuilong 已提交
31 32 33
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
F
flytocc 已提交
34
from ppcls.utils.ema import ExponentialMovingAverage
D
dongshuilong 已提交
35 36 37 38 39 40 41
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
42 43
from .train import build_train_epoch_func
from .evaluation import build_eval_func
44
from ppcls.engine.train.utils import type_name
D
dongshuilong 已提交
45
from ppcls.engine import evaluation
D
dongshuilong 已提交
46 47 48
from ppcls.arch.gears.identity_head import IdentityHead


D
dongshuilong 已提交
49
class Engine(object):
    """Build and drive a PaddleClas run described by ``config``.

    ``mode`` (one of "train", "eval", "infer", "export") selects which public
    entry point may be used afterwards: ``train()``, ``eval()`` (also called
    from ``train()``), ``infer()`` or ``export()``.
    """

    def __init__(self, config, mode="train"):
        """Create every component the run may need.

        Args:
            config: parsed PaddleClas config (dict-like; ``config["Global"]``
                is read throughout).
            mode: "train", "eval", "infer" or "export".
        """
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config

        # Both attributes are read later (output_dir by _init_vdl()/train(),
        # eval_mode by infer()) but were never assigned before.
        self.output_dir = self.config["Global"].get("output_dir")
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")

        # init logger first: _init_seed() emits log messages in
        # distributed runs.
        init_logger(self.config, mode=mode)

        # set seed
        self._init_seed()

        # for visualdl (uses self.output_dir)
        self.vdl_writer = self._init_vdl()

        # init train_func and eval_func
        self.train_epoch_func = build_train_epoch_func(self.config)
        self.eval_epoch_func = build_eval_func(self.config)

        # set device
        self._init_device()

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # build dataloader
        self.dataloader_dict = build_dataloader(self)
        self.train_dataloader = self.dataloader_dict["Train"]
        self.unlabel_train_dataloader = self.dataloader_dict["UnLabelTrain"]
        self.eval_dataloader = self.dataloader_dict["Eval"]
        self.gallery_query_dataloader = self.dataloader_dict["GalleryQuery"]
        self.gallery_dataloader = self.dataloader_dict["Gallery"]
        self.query_dataloader = self.dataloader_dict["Query"]

        # build loss
        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
            self.config, self.mode)

        # build metric
        self.train_metric_func, self.eval_metric_func = build_metrics(self)

        # build model
        self.model = build_model(self.config, self.mode)

        # load_pretrain
        self._init_pretrained()

        # build optimizer
        self.optimizer, self.lr_sch = build_optimizer(
            self.config, self.train_dataloader,
            [self.model, self.train_loss_func])

        # AMP training and evaluating
        self._init_amp()

        # for distributed
        self._init_dist()

        print_config(config)

    def train(self):
        """Run the training loop.

        Resumes from ``Global.checkpoints`` when set. Evaluates every
        ``Global.eval_interval`` epochs (when ``Global.eval_during_train`` is
        on and the epoch is past ``Global.start_eval_epoch``). Keeps the best
        model (and the best EMA model when ``EMA`` is configured). Saves
        periodic ("epoch_N") and "latest" checkpoints.
        """
        assert self.mode == "train"
        global_cfg = self.config["Global"]
        print_batch_step = global_cfg['print_batch_step']
        save_interval = global_cfg["save_interval"]
        epochs = global_cfg["epochs"]
        model_name = self.config["Arch"]["name"]
        best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))
            best_metric_ema = 0.0
            ema_module = self.model_ema.module
        else:
            ema_module = None

        # Filled during train_epoch_func; every value must expose `avg_info`
        # (used for the per-epoch summary). Presumably name -> AverageMeter.
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        # resume: restores weights/optimizer state and, if recorded, the
        # best metric and the epoch to continue from
        if global_cfg.get("checkpoints") is not None:
            metric_info = init_model(global_cfg, self.model, self.optimizer,
                                     self.train_loss_func, ema_module)
            if metric_info is not None:
                best_metric.update(metric_info)

        # `start_eval_epoch` is 1-based in the config; evaluation starts
        # once epoch_id exceeds (start_eval_epoch - 1)
        start_eval_epoch = global_cfg.get("start_eval_epoch", 0) - 1

        for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            metric_msg = ", ".join(
                [meter.avg_info for meter in self.output_info.values()])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, epochs, metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            if global_cfg["eval_during_train"] and epoch_id % global_cfg[
                    "eval_interval"] == 0 and epoch_id > start_eval_epoch:
                acc = self.eval(epoch_id)

                # step lr (by epoch) according to given metric, such as acc
                for lr_sch in self.lr_sch:
                    if getattr(lr_sch, "by_epoch", False) and \
                            type_name(lr_sch) == "ReduceOnPlateau":
                        lr_sch.step(acc)

                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        ema=ema_module,
                        model_name=model_name,
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

                if self.ema:
                    # evaluate the EMA weights by temporarily swapping them
                    # in as self.model, then restore the trained model
                    ori_model, self.model = self.model, ema_module
                    acc_ema = self.eval(epoch_id)
                    self.model = ori_model
                    ema_module.eval()

                    if acc_ema > best_metric_ema:
                        best_metric_ema = acc_ema
                        save_load.save_model(
                            self.model,
                            self.optimizer,
                            {"metric": acc_ema,
                             "epoch": epoch_id},
                            self.output_dir,
                            ema=ema_module,
                            model_name=model_name,
                            prefix="best_model_ema",
                            loss=self.train_loss_func)
                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
                        epoch_id, best_metric_ema))
                    logger.scaler(
                        name="eval_acc_ema",
                        value=acc_ema,
                        step=epoch_id,
                        writer=self.vdl_writer)

            # save model
            if save_interval > 0 and epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    ema=ema_module,
                    model_name=model_name,
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                ema=ema_module,
                model_name=model_name,
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate ``self.model`` and return the eval function's result.

        The model is put in eval mode for the call and switched back to
        train mode afterwards.
        """
        assert self.mode in ["train", "eval"]
        self.model.eval()
        # The eval function built in __init__ is stored as eval_epoch_func;
        # the previous `self.eval_func` was never assigned.
        eval_result = self.eval_epoch_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        """Run batched classification inference over ``Infer.infer_imgs``.

        Each trainer processes every ``world_size``-th image starting at its
        own rank, and prints the post-processed result of every batch.
        """
        assert self.mode == "infer" and self.eval_mode == "classification"

        self.preprocess_func = create_operators(self.config["Infer"][
            "transforms"])
        self.postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]
        last_idx = len(image_list) - 1

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            # flush a full batch, or the trailing partial batch
            if len(batch_data) >= batch_size or idx == last_idx:
                batch_tensor = paddle.to_tensor(batch_data)

                if self.amp and self.amp_eval:
                    with paddle.amp.auto_cast(
                            custom_black_list={
                                "flatten_contiguous_range", "greater_than"
                            },
                            level=self.amp_level):
                        out = self.model(batch_tensor)
                else:
                    out = self.model(batch_tensor)

                # unwrap the various model output containers down to logits
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
                result = self.postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()

    def export(self):
        """Export ``self.model`` (wrapped in ExportModel) as a static inference
        model under ``Global.save_inference_dir``/"inference" (or a quantized
        "_int8" model when the base model carries a ``quanter``)."""
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get(
            "use_multilabel",
            False) or "ATTRMetric" in self.config["Metric"]["Eval"][0]
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(model.base_model,
                                               pretrained_model)
            else:
                load_dygraph_pretrain(model.base_model, pretrained_model)

        model.eval()

        # for re-parameterization nets; layers without an `is_repped` flag
        # are treated as not yet re-parameterized
        for layer in self.model.sublayers():
            if hasattr(layer, "re_parameterize") and not getattr(
                    layer, "is_repped", False):
                layer.re_parameterize()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")

        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + self.config["Global"]["image_shape"],
                    dtype='float32')
            ])
        if hasattr(model.base_model,
                   "quanter") and model.base_model.quanter is not None:
            model.base_model.quanter.save_quantized_model(model,
                                                          save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)
        logger.info(
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )

    def _init_vdl(self):
        """Return a VisualDL LogWriter under ``<output_dir>/vdl`` on rank 0
        when training with ``Global.use_visualdl`` enabled, else None."""
        # was `mode == "train"`: `mode` is not defined in this scope
        if self.config['Global'].get(
                'use_visualdl',
                False) and self.mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            os.makedirs(vdl_writer_path, exist_ok=True)
            return LogWriter(logdir=vdl_writer_path)
        return None

    def _init_seed(self):
        """Seed paddle, numpy and ``random`` from ``Global.seed``.

        Distributed runs seed each trainer with ``seed + rank``, falling back
        to 42 when no seed is configured. Single-process runs seed only when
        a seed is configured. ``0`` is a valid configured seed.
        """
        seed = self.config["Global"].get("seed", None)
        # `False` is still accepted as "no seed"; it must be tested by
        # identity because `0 == False` in Python (the previous checks
        # `not seed` / `seed == 0` confused the two).
        seed_is_set = seed is not None and seed is not False
        if dist.get_world_size() != 1:
            # set different seed in different GPU manually in distributed environment
            if not seed_is_set:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
            )
            dist_seed = int(seed) + dist.get_rank()
            paddle.seed(dist_seed)
            np.random.seed(dist_seed)
            random.seed(dist_seed)
        elif seed_is_set:
            assert isinstance(seed, int), "The 'seed' must be a integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

    def _init_device(self):
        """Validate ``Global.device`` and make it paddle's current device."""
        device = self.config["Global"]["device"]
        assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, device))
        paddle.set_device(device)

    def _init_pretrained(self):
        """Load ``Global.pretrained_model`` (URL or local path) into the model
        and, when present, the train loss function."""
        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is None:
            return
        targets = [self.model, getattr(self, 'train_loss_func', None)]
        if pretrained_model.startswith("http"):
            load_dygraph_pretrain_from_url(targets, pretrained_model)
        else:
            load_dygraph_pretrain(targets, pretrained_model)

    def _init_amp(self):
        """Configure automatic mixed precision from the ``AMP`` config section.

        Sets ``self.amp`` / ``self.amp_eval`` (and, when AMP is on,
        ``scale_loss``, ``use_dynamic_loss_scaling``, ``scaler`` and
        ``amp_level``), and decorates the model / optimizer / train loss with
        ``paddle.amp.decorate`` as required.
        """
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        # for amp
        if not self.amp:
            return

        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            AMP_RELATED_FLAGS_SETTING.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

        self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
        self.use_dynamic_loss_scaling = self.config["AMP"].get(
            "use_dynamic_loss_scaling", False)
        self.scaler = paddle.amp.GradScaler(
            init_loss_scaling=self.scale_loss,
            use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

        self.amp_level = self.config['AMP'].get("level", "O1")
        if self.amp_level not in ["O1", "O2"]:
            msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
            logger.warning(msg)
            self.config['AMP']["level"] = "O1"
            self.amp_level = "O1"

        self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
        # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
        if self.mode == "train" and self.config["Global"].get(
                "eval_during_train",
                True) and self.amp_level == "O2" and not self.amp_eval:
            msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
            logger.warning(msg)
            self.config["AMP"]["use_fp16_test"] = True
            self.amp_eval = True

        # TODO(gaotingquan): to compatible with different versions of Paddle
        # Compare (major, minor) numerically: the previous check
        # `__version__[:3] not in ["2.3", "0.0"]` treated 2.4+ as "< 2.3".
        # Develop builds report "0.0.x"; an unparsable version is assumed to
        # be a develop build as well.
        try:
            paddle_major_minor = tuple(
                int(part) for part in paddle.__version__.split(".")[:2])
        except ValueError:
            paddle_major_minor = (0, 0)
        is_new_paddle = paddle_major_minor >= (
            2, 3) or paddle_major_minor == (0, 0)

        # paddle version < 2.3.0 and not develop
        if not is_new_paddle:
            if self.mode == "train":
                self.model, self.optimizer = paddle.amp.decorate(
                    models=self.model,
                    optimizers=self.optimizer,
                    level=self.amp_level,
                    save_dtype='float32')
            elif self.amp_eval:
                if self.amp_level == "O2":
                    msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
                    logger.warning(msg)
                    self.amp_eval = False
                else:
                    # without `optimizers`, decorate() returns only the
                    # model(s); unpacking into (model, optimizer) failed
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')
        # paddle version >= 2.3.0 or develop
        elif self.mode == "train" or self.amp_eval:
            self.model = paddle.amp.decorate(
                models=self.model,
                level=self.amp_level,
                save_dtype='float32')

        # a loss with learnable parameters must be decorated as well
        if self.mode == "train" and len(self.train_loss_func.parameters(
        )) > 0:
            self.train_loss_func = paddle.amp.decorate(
                models=self.train_loss_func,
                level=self.amp_level,
                save_dtype='float32')

    def _init_dist(self):
        """Record ``Global.distributed`` and, for multi-trainer runs, init the
        parallel env and wrap the model (and a parameterized train loss) in
        ``paddle.DataParallel``."""
        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        # TODO(gaotingquan):
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
                logger.warning(msg)

        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

D
dongshuilong 已提交
493

W
dbg  
weishengyu 已提交
494
class ExportModel(TheseusLayer):
    """Inference-export wrapper around a trained model.

    Optionally narrows the wrapped model's output (by sub-model name for
    distillation models, and/or by output key) and appends a final
    activation: Sigmoid for multi-label models, otherwise Softmax unless the
    Arch config sets ``infer_add_softmax`` to False.
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model

        # a distillation model holds several sub-networks; export only the
        # one named in the config
        self.infer_model_name = (config["infer_model_name"]
                                 if isinstance(self.base_model,
                                               DistillationModel) else None)

        self.infer_output_key = config.get("infer_output_key", None)
        wants_features = self.infer_output_key == "features"
        if wants_features and isinstance(self.base_model, RecModel):
            # swap in IdentityHead when exporting a RecModel's "features"
            self.base_model.head = IdentityHead()

        if use_multilabel:
            self.out_act = nn.Sigmoid()
        elif config.get("infer_add_softmax", True):
            self.out_act = nn.Softmax(axis=-1)
        else:
            self.out_act = None

    def eval(self):
        """Switch this wrapper and every sublayer to inference mode."""
        self.training = False
        for sublayer in self.sublayers():
            sublayer.training = False
            sublayer.eval()

    def forward(self, x):
        out = self.base_model(x)
        # list outputs: only the first element is exported
        if isinstance(out, list):
            out = out[0]
        if self.infer_model_name is not None:
            out = out[self.infer_model_name]
        if self.infer_output_key is not None:
            out = out[self.infer_output_key]
        if self.out_act is not None:
            # dict outputs carry the tensor to activate under "logits"
            if isinstance(out, dict):
                out = out["logits"]
            out = self.out_act(out)
        return out