auto_cast.py 22.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17 18
from paddle.fluid.wrapped_decorator import (
    signature_safe_contextmanager,
    wrap_decorator,
)
19 20
from paddle.fluid import core
import contextlib
21 22 23 24 25 26 27 28 29 30
from paddle.fluid.framework import (
    Variable,
    _non_static_mode,
    OpProtoHolder,
    Parameter,
    _dygraph_tracer,
    dygraph_only,
    set_flags,
    get_flags,
)
31 32
import warnings
import copy
33 34 35 36
import functools
import paddle
import operator
import types
37

L
Leo Chen 已提交
38 39
AMP_LEVEL = core.AmpLevel

40
__all__ = ['amp_guard', 'amp_decorate']
41 42 43 44 45 46

# Interpolation ops kept in fp32 by default in both O1 and O2 float16 modes:
# the FP16 performance of their grad ops is worse than that of FP32.
_INTERP_V2_OPS = (
    'linear_interp_v2',
    'nearest_interp_v2',
    'bilinear_interp_v2',
    'bicubic_interp_v2',
    'trilinear_interp_v2',
)

# O1 / float16 white list: ops that support fp16 calculation and are
# considered numerically safe and performance-critical. These ops are always
# converted to fp16.
WHITE_LIST = {
    # dense compute
    'conv2d',
    'matmul',
    'matmul_v2',
    'mul',
    # fake quantize/dequantize ops
    'fake_quantize_dequantize_abs_max',
    'fake_quantize_dequantize_moving_average_abs_max',
}

# O1 / float16 black list: ops that support fp16 calculation but are
# considered numerically dangerous, and whose effects may also be observed in
# downstream ops. These ops stay in fp32.
BLACK_LIST = {
    # elementwise math and reductions
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    # softmax and cross-entropy losses
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # fp32 by default avoids returning inf when the sum exceeds 65504
    # (the largest finite fp16 value)
    'reduce_sum',
    *_INTERP_V2_OPS,
}

# Suggested values for cuDNN-related flags that matter for AMP. amp_guard does
# not apply them automatically; users set them manually.
AMP_RELATED_FLAGS_SETTING = {
    'FLAGS_cudnn_exhaustive_search': 1,
    'FLAGS_conv_workspace_size_limit': 1000,
    'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}

# Names of the flags above, in the same order.
AMP_RELATED_FLAGS = list(AMP_RELATED_FLAGS_SETTING)

# O2 (pure fp16) lists: every op runs in fp16 except those in the black list.
PURE_FP16_WHITE_LIST = set()
PURE_FP16_BLACK_LIST = {
    'lookup_table',
    'lookup_table_v2',
    'scatter',
    'scatter_grad',
    *_INTERP_V2_OPS,
}

# O1 / bfloat16 lists.
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()

# O2 (pure bf16) lists. Each is a distinct set object (no aliasing).
PURE_BF16_WHITE_LIST = set()
PURE_BF16_BLACK_LIST = set()
109

L
Leo Chen 已提交
110 111 112 113 114 115 116
# Module-level slot that amp_guard fills with a dict of its call arguments
# while its context is active; None until something is stored.
_g_amp_state_ = None


def amp_state():
    """Return the argument dict stored by ``amp_guard``, or None if none is stored."""
    # Reading a module global needs no ``global`` declaration.
    return _g_amp_state_

117

118
# NOTE(zhiqiu): similar to paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list.
# AutoMixedPrecisionLists is not reused because its custom_black_varnames is not suitable for imperative mode.
def _update_list(
    custom_white_list, custom_black_list, level='O1', dtype='float16'
):
    """
    Build the (white_list, black_list) op-name sets for an AMP level/dtype,
    merged with the user's custom lists.

    Args:
        custom_white_list: iterable of op names to force into the white list, or None.
        custom_black_list: iterable of op names to force into the black list, or None.
        level (str): 'O1' selects the mixed-precision defaults; any other
            value selects the pure-mode (PURE_*) defaults.
        dtype (str): 'float16' selects the fp16 defaults; any other value
            selects the bf16 defaults.

    Returns:
        tuple(set, set): shallow copies of the defaults with the custom ops
        applied; the module-level default sets are never modified.

    Raises:
        ValueError: if an op name appears in both custom lists.
    """
    if dtype == 'float16':
        if level == 'O1':
            defaults = (WHITE_LIST, BLACK_LIST)
        else:
            defaults = (PURE_FP16_WHITE_LIST, PURE_FP16_BLACK_LIST)
    elif level == 'O1':
        defaults = (BF16_WHITE_LIST, BF16_BLACK_LIST)
    else:
        defaults = (PURE_BF16_WHITE_LIST, PURE_BF16_BLACK_LIST)
    white, black = map(copy.copy, defaults)

    if custom_white_list and custom_black_list:
        if any(name in custom_black_list for name in custom_white_list):
            raise ValueError("Custom white list overlap custom black list")

    # White entries are applied first, then black entries, so a later black
    # entry always wins over the defaults.
    for name in custom_white_list or ():
        black.discard(name)
        white.add(name)
    for name in custom_black_list or ():
        white.discard(name)
        black.add(name)
    return white, black


159 160 161 162 163 164
def _in_amp_guard():
    """
    Return True when the current dygraph tracer's AMP level is O1, i.e. the
    code runs inside an O1 ``amp_guard``. Return False otherwise, including
    when there is no tracer.
    """
    tracer = _dygraph_tracer()
    if not tracer:
        return False
    return bool(tracer._amp_level == core.AmpLevel.O1)


173 174 175 176 177
def _in_pure_fp16_guard():
    """
    Return True when the current dygraph tracer's AMP level is O2 (pure
    low-precision mode). Return False otherwise, including when there is no
    tracer.

    Always returns a bool, consistent with ``_in_amp_guard``. The previous
    ``tracer and ...`` form returned the falsy tracer (None) itself.
    """
    tracer = _dygraph_tracer()
    return bool(tracer and tracer._amp_level == core.AmpLevel.O2)


178 179 180 181 182 183 184 185 186 187 188 189 190 191
def _is_gpu_float16_supported():
    """
    Return True if the current GPU's compute-capability major version is at
    least 7, the minimum this module accepts for float16 AMP.
    """
    major_version = paddle.device.cuda.get_device_capability()[0]
    return major_version >= 7


def _is_gpu_bfloat16_supported():
    """
    Return True if the current GPU's compute-capability major version is at
    least 8 and the CUDA version reported by Paddle has major version >= 11,
    the requirements this module applies for bfloat16 AMP.
    """
    capability = paddle.device.cuda.get_device_capability()
    cuda_version = paddle.version.cuda()
    # paddle.version.cuda() may report None or the string 'False'
    # (presumably for builds without CUDA); both count as "no CUDA".
    has_cuda = cuda_version is not None and cuda_version != 'False'
    cuda_major_ok = has_cuda and int(cuda_version.split('.')[0]) >= 11
    return capability[0] >= 8 and cuda_major_ok


199
@dygraph_only
def pure_fp16_initialize(models):
    """
    Convert every sublayer of each model (the model itself included) to
    float16 in place, for O2 (pure fp16) training.

    Every visited layer gets ``_casted_by_pure_fp16 = True``. Layers whose
    ``_dtype`` is already 'float16' and normalization layers (BatchNorm,
    BatchNorm1D/2D/3D, LayerNorm, SyncBatchNorm) are not converted. The fused
    incubate layers (FusedFeedForward, FusedMultiHeadAttention) are converted
    through their own ``_amp_decorate`` method. All other layers are cast with
    ``_to_impl(..., include_sublayers=False, floating_only=True)``.

    Args:
        models (list[paddle.nn.Layer]): models to convert.

    Returns:
        list[paddle.nn.Layer]: the same ``models`` list that was passed in.
    """
    for model in models:
        for layer in model.sublayers(include_self=True):
            layer._casted_by_pure_fp16 = True

            keeps_its_dtype = layer._dtype == 'float16' or isinstance(
                layer,
                (
                    paddle.nn.BatchNorm,
                    paddle.nn.BatchNorm1D,
                    paddle.nn.BatchNorm2D,
                    paddle.nn.BatchNorm3D,
                    paddle.nn.LayerNorm,
                    paddle.nn.SyncBatchNorm,
                ),
            )
            if keeps_its_dtype:
                continue

            is_fused_layer = isinstance(
                layer,
                (
                    paddle.incubate.nn.FusedFeedForward,
                    paddle.incubate.nn.FusedMultiHeadAttention,
                ),
            )
            if is_fused_layer:
                layer._amp_decorate(dtype='float16')
            else:
                layer._to_impl(
                    dtype='float16', include_sublayers=False, floating_only=True
                )
    return models
229 230


231 232 233 234
@dygraph_only
def pure_bf16_initialize(models):
    """
    Convert every sublayer of each model (the model itself included) to
    bfloat16 in place, using ``_to_impl(..., include_sublayers=False,
    floating_only=True)``. Unlike the fp16 path, no layer type is skipped.

    Args:
        models (list[paddle.nn.Layer]): models to convert.

    Returns:
        list[paddle.nn.Layer]: the same ``models`` list that was passed in.
    """
    all_layers = (
        layer
        for model in models
        for layer in model.sublayers(include_self=True)
    )
    for layer in all_layers:
        layer._to_impl(
            dtype='bfloat16', include_sublayers=False, floating_only=True
        )
    return models


241 242 243 244
def check_models(models):
    """
    Validate the models passed to ``amp_decorate`` for O2 training.

    Args:
        models (list): candidate models.

    Raises:
        RuntimeError: if an element is not a ``paddle.nn.Layer``, or if it is
            already a ``paddle.DataParallel`` wrapper (models must be
            decorated first and wrapped in DataParallel afterwards).
    """
    for model in models:
        if not isinstance(model, paddle.nn.Layer):
            raise RuntimeError(
                "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".format(
                    type(model)
                )
            )
        # DataParallel is itself a Layer, so it passes the check above and
        # needs its own rejection.
        if isinstance(model, paddle.DataParallel):
            raise RuntimeError(
                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate origin model, and then call paddle.DataParallel to get distributed model."
            )
253 254 255 256


def check_optimizers(optimizers):
    """
    Validate the optimizers passed to ``amp_decorate`` for O2 training.

    Args:
        optimizers (list): candidate optimizers.

    Raises:
        RuntimeError: if an element is neither a ``paddle.optimizer.Optimizer``
            nor a ``paddle.fluid.optimizer.Optimizer``.
    """
    for candidate in optimizers:
        accepted = (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)
        if isinstance(candidate, accepted):
            continue
        raise RuntimeError(
            "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer "
            f"or paddle.fluid.optimizer.Optimizer, but receive {type(candidate)}."
        )
266 267


268 269
@signature_safe_contextmanager
@dygraph_only
def amp_guard(
    enable=True,
    custom_white_list=None,
    custom_black_list=None,
    level='O1',
    dtype='float16',
):
    """
    :api_attr: imperative

    Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
    If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
    by autocast algorithm for better performance.

    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
    imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.

    Args:
        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
        custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
             fp16 calculation and are considered numerically-safe and performance-critical. These ops
             will be converted to fp16.
        custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
             calculation and are considered numerically-dangerous and whose effects may also be
             observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level (case-insensitive). Accepted values are "O0", "O1" and "O2":
             O0 represents fp32 train mode (AMP disabled);
             O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
             O2 represent Pure fp16/bf16, all operators parameters and input data will be casted to fp16/bf16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
        dtype(str, optional): Whether to use 'float16' or 'bfloat16' (case-insensitive). Default is 'float16'.

    Raises:
        ValueError: if ``level`` or ``dtype`` is not an accepted value, or if
            there is no active dygraph tracer.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with paddle.fluid.dygraph.guard():
            conv2d = paddle.fluid.dygraph.Conv2D(3, 2, 3)
            data = paddle.fluid.dygraph.to_variable(data)
            with paddle.fluid.dygraph.amp_guard():
                conv = conv2d(data)
                print(conv.dtype) # FP16
            with paddle.fluid.dygraph.amp_guard(enable=False):
                conv = conv2d(data)
                print(conv.dtype) # FP32

    """
    # Snapshot the call arguments before any other local is created. dict()
    # copies the mapping: in CPython < 3.13 the dict returned by locals() is
    # the frame's shared f_locals and can be refreshed later (e.g. by a trace
    # function), which would leak every later local into the stored state.
    amp_state = dict(locals())
    global _g_amp_state_

    # check amp_level: O0-O2
    level = level.upper()
    if level not in ('O0', 'O1', 'O2'):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    # check amp_dtype: float16 or bfloat16
    dtype = dtype.lower()
    if dtype not in ('float16', 'bfloat16'):
        raise ValueError("dtype should be 'float16' or 'bfloat16'.")

    # check tracer
    tracer = _dygraph_tracer()
    if not tracer:
        raise ValueError(
            "current_tracer is None, maybe it is not in imperative mode."
        )

    # check device_type:
    # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16.
    # Maybe we will support cpu for bfloat16.
    place = tracer._expected_place
    if enable and not (
        place.is_gpu_place()
        or place.is_xpu_place()
        or place.is_mlu_place()
        or place.is_npu_place()
        or place.is_custom_place()
    ):
        warnings.warn(
            'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
            % place
        )
        enable = False

    # NPU, XPU, MLU and custom devices only support float16 amp.
    if dtype == 'bfloat16':
        for is_place, message in (
            (place.is_npu_place, 'NPUPlace only support float16 amp.'),
            (place.is_xpu_place, 'XPUPlace only support float16 amp.'),
            (place.is_mlu_place, 'MLUPlace only support float16 amp.'),
            (place.is_custom_place, 'CustomPlace only support float16 amp.'),
        ):
            if is_place():
                warnings.warn(message)
                enable = False

    # For gpu float16: Compute Capability should >= 7.
    # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
    # These only warn; AMP stays enabled.
    if place.is_gpu_place():
        if (dtype == 'float16') and not _is_gpu_float16_supported():
            prop = paddle.device.cuda.get_device_capability()
            warnings.warn(
                "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
                % (paddle.device.cuda.get_device_name(), prop[0], prop[1])
            )
        elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
            prop = paddle.device.cuda.get_device_capability()
            cuda_version = paddle.version.cuda()
            warnings.warn(
                "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
                % (
                    paddle.device.cuda.get_device_name(),
                    prop[0],
                    prop[1],
                    cuda_version,
                )
            )

    amp_level = {
        'O0': AMP_LEVEL.O0,
        'O1': AMP_LEVEL.O1,
        'O2': AMP_LEVEL.O2,
    }[level]

    # Default op lists. bfloat16 uses the BF16 lists at every level.
    # NOTE(review): for O2 + bfloat16, _update_list picks PURE_BF16_* instead;
    # behavior kept as-is, confirm whether the difference is intentional.
    if dtype == 'bfloat16':
        _white_list, _black_list = BF16_WHITE_LIST, BF16_BLACK_LIST
    elif level == 'O2':
        _white_list, _black_list = PURE_FP16_WHITE_LIST, PURE_FP16_BLACK_LIST
    else:
        _white_list, _black_list = WHITE_LIST, BLACK_LIST

    if custom_white_list or custom_black_list:
        _white_list, _black_list = _update_list(
            custom_white_list, custom_black_list, level, dtype
        )

    amp_dtype = dtype
    if not enable:
        amp_level = AMP_LEVEL.O0
        amp_dtype = "float32"

    # TODO(zhiqiu) set amp related flags automatically in this guard
    # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
    # batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed inside amp_guard.
    # So, users need to set related flags manually.
    # original_flags = get_flags(AMP_RELATED_FLAGS)
    # set_flags(AMP_RELATED_FLAGS_SETTING)

    # Save everything about to change before touching any of it.
    original_state = _g_amp_state_
    original_amp_level = tracer._amp_level
    original_white_list, original_black_list = tracer._get_amp_op_list()
    original_amp_dtype = tracer._amp_dtype

    try:
        # State is published only after all validation passed, so a rejected
        # call never leaves a stale snapshot in _g_amp_state_. Setup lives
        # inside the try so a partial failure is still rolled back.
        _g_amp_state_ = amp_state
        tracer._amp_level = amp_level
        tracer._set_amp_op_list(_white_list, _black_list)
        tracer._amp_dtype = amp_dtype
        yield
    finally:
        _g_amp_state_ = original_state
        tracer._amp_level = original_amp_level
        tracer._set_amp_op_list(original_white_list, original_black_list)
        # set_flags(original_flags)
        tracer._amp_dtype = original_amp_dtype
464 465


466
class StateDictHook:
    """
    State-dict hook that casts floating-point entries to ``save_dtype``.

    ``amp_decorate`` registers one on every sublayer when ``save_dtype`` is
    given. Calling the hook replaces entries of the received state dict in
    place with cast copies that keep the original ``name``. The tensors
    originally stored in the dict are not modified.
    """

    def __init__(self, save_dtype):
        self._save_dtype = save_dtype

    def __call__(self, state_dict):
        target_dtype = self._save_dtype
        # Iterate over a snapshot: entries are reassigned while looping.
        for key, tensor in list(state_dict.items()):
            # Each entry is handled under its own fresh dygraph guard.
            with paddle.fluid.dygraph.guard():
                if not paddle.is_floating_point(tensor):
                    continue
                casted = paddle.cast(tensor, target_dtype)
                casted.name = tensor.name
                state_dict[key] = casted
478 479 480


@dygraph_only
def amp_decorate(
    models,
    optimizers=None,
    level='O1',
    dtype='float16',
    master_weight=None,
    save_dtype=None,
):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
    When level is O2(pure fp16/bf16), the decorate will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm.

    Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.

    All arguments are validated before any model or optimizer is modified, so an
    invalid argument never leaves a model half-converted.

    Args:
        models(Layer|list of Layer): The defined models by user, models must be either a single model or a list of models.
        optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
             O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
        master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
        save_dtype(str, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`, it should be float16, bfloat16, float32, float64 or None.
             The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.

    Returns:
        The decorated model(s), in the same form (single or list) as passed in;
        when ``optimizers`` is given, a ``(models, optimizers)`` tuple with the
        optimizers also in their original form.

    Raises:
        ValueError: if ``level`` is not 'O1'/'O2', or ``save_dtype`` is not supported.
        TypeError: if ``models``/``optimizers`` have the wrong container type, or ``dtype`` is not supported.
        RuntimeError: from ``check_models`` / ``check_optimizers`` for invalid elements.

    Examples:

     .. code-block:: python

        # required: gpu
        # Demo1: single model and optimizer:
        import paddle

        model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        optimizer = paddle.optimizer.SGD(parameters=model.parameters())

        model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimizer, level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = model(data)
            print(output.dtype) # FP16

        # required: gpu
        # Demo2: multi models and optimizers:
        model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())

        models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = models[0](data)
            output2 = models[1](data)
            print(output.dtype) # FP16
            print(output2.dtype) # FP16

        # required: gpu
        # Demo3: optimizers is None:
        model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)

        model = paddle.fluid.dygraph.amp_decorate(models=model3, level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = model(data)
            print(output.dtype) # FP16
    """
    if level not in ('O1', 'O2'):
        raise ValueError(
            "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
        )

    if level == 'O1':
        # O1 needs no decoration: hand the inputs back unchanged.
        if optimizers is None:
            return models
        return models, optimizers

    # ---- Validation (same order as before; nothing is mutated here) ----
    if isinstance(models, paddle.nn.Layer):
        models_is_list = False
        models = [models]
    elif isinstance(models, list):
        models_is_list = True
    else:
        raise TypeError(
            "models must be either a single model or a list of models."
        )
    check_models(models)

    if dtype not in ('float16', 'bfloat16'):
        raise TypeError("dtype only support float16 or bfloat16.")

    optimizers_is_list = False
    if optimizers is not None:
        if isinstance(
            optimizers,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
        ):
            optimizers = [optimizers]
        elif isinstance(optimizers, list):
            optimizers_is_list = True
        else:
            raise TypeError(
                "optimizers must be either a single optimizer or a list of optimizers."
            )
        check_optimizers(optimizers)

    if save_dtype is not None and save_dtype not in (
        'float16',
        'bfloat16',
        'float32',
        'float64',
    ):
        # The 1-tuple keeps %-formatting safe for any input type.
        raise ValueError(
            "save_dtype can only be float16, bfloat16, float32 or float64, but your input save_dtype is %s."
            % (save_dtype,)
        )

    # ---- Mutation: only reached once every argument is valid ----
    if dtype == 'float16':
        models = pure_fp16_initialize(models=models)
    else:
        models = pure_bf16_initialize(models=models)

    if optimizers is not None:
        # support master_weight: only an explicit False disables
        # multi-precision; None (the default) and True enable it.
        use_multi_precision = master_weight is not False
        for optimizer in optimizers:
            if hasattr(optimizer, '_multi_precision'):
                optimizer._multi_precision = use_multi_precision

    if save_dtype is not None:
        for model in models:
            for layer in model.sublayers(include_self=True):
                layer.register_state_dict_hook(StateDictHook(save_dtype))

    # Return values in the same shape (single object vs list) as the inputs.
    decorated_models = models if models_is_list else models[0]
    if optimizers is None:
        return decorated_models
    decorated_optimizers = optimizers if optimizers_is_list else optimizers[0]
    return decorated_models, decorated_optimizers