grad_scaler.py 48.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import warnings
16
from collections import defaultdict
17
from enum import Enum
18

19
import numpy as np
20

21
from paddle import _legacy_C_ops
22 23 24 25 26 27 28 29 30 31
from paddle.fluid import core, in_dygraph_mode
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import _dygraph_tracer, dygraph_only


class OptimizerState(Enum):
    """Progress marker the scaler keeps per optimizer (keyed by `id(optimizer)`)."""

    # Default state produced by `_refresh_optimizer_state()`.
    INIT = 0
    # Set at the end of `AmpScaler._unscale()`, after gradients were unscaled.
    UNSCALED = 1
    # Set by `GradScaler.step()` after the (possibly skipped) optimizer step.
    STEPPED = 2
32 33


34 35 36 37
def _refresh_optimizer_state():
    """Build a fresh per-optimizer record whose state is `OptimizerState.INIT`.

    Used as the `defaultdict` factory for the scaler's optimizer-state table.
    """
    return dict(state=OptimizerState.INIT)


38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
class AmpScaler:
    """
    AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
    mode. It controls the scaling of loss, helps avoiding numerical overflow.
    The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters.

    `scale()` is used to multiply the loss by a scale ratio.
    `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
    `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling.

    Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
    imperative mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Default is 2.0.
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
                        the loss scaling. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 1.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        An AmpScaler object.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        model = paddle.nn.Conv2D(3, 2, 3)
        optimizer = paddle.optimizer.SGDOptimizer(
                learning_rate=0.01, parameter_list=model.parameters())
        scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
        data = paddle.to_tensor(data)
        with paddle.amp.amp_guard():
            conv = model(data)
            loss = paddle.mean(conv)
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.minimize(optimizer, scaled)
    """

    @dygraph_only
    def __init__(
        self,
        enable=True,
        init_loss_scaling=2.0**15,
        incr_ratio=2.0,
        decr_ratio=0.5,
        incr_every_n_steps=1000,
        decr_every_n_nan_or_inf=1,
        use_dynamic_loss_scaling=True,
    ):
        """
        Create the scaler and, when enabled, allocate its scale / found-inf tensors.

        Args:
            enable(bool, optional): Enable loss scaling or not. Forced to False
                (with a warning) when the current place is not one of the
                supported device places. Default is True.
            init_loss_scaling(float, optional): The initial loss scaling factor. Default is 2**15.
            incr_ratio(float, optional): Multiplier used when increasing the loss scaling. Must be > 1.0.
            decr_ratio(float, optional): Multiplier used when decreasing the loss scaling. Must be < 1.0.
            incr_every_n_steps(int, optional): Increase the loss scaling every n consecutive
                steps with finite gradients. Default is 1000.
            decr_every_n_nan_or_inf(int, optional): Decrease the loss scaling every n
                accumulated steps with nan or inf gradients. Default is 1.
            use_dynamic_loss_scaling(bool, optional): Whether the loss scaling is updated dynamically.

        Raises:
            ValueError: If there is no dygraph tracer (not in imperative mode).
            AssertionError: If the scaler is enabled and a ratio is out of range.
        """
        tracer = _dygraph_tracer()
        if not tracer:
            raise ValueError(
                "current_tracer is None, maybe it is not in imperative mode."
            )

        if enable and not (
            tracer._expected_place.is_gpu_place()
            or tracer._expected_place.is_xpu_place()
            or tracer._expected_place.is_mlu_place()
            or tracer._expected_place.is_npu_place()
            or tracer._expected_place.is_custom_place()
        ):
            warnings.warn(
                'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
                % tracer._expected_place
            )
            enable = False

        self._enable = enable

        # The ratios are only validated when scaling is actually in effect.
        if self._enable:
            assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
            assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."

        # Plain-Python configuration is stored even for a disabled scaler so
        # that the `get_*` / `is_use_dynamic_loss_scaling()` accessors do not
        # raise AttributeError on it.
        self._init_loss_scaling = init_loss_scaling
        self._incr_ratio = incr_ratio
        self._decr_ratio = decr_ratio
        self._incr_every_n_steps = incr_every_n_steps
        self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self._incr_count = 0
        self._decr_count = 0
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

        if not self._enable:
            # No tensors or optimizer bookkeeping are needed when disabled.
            return

        def _false_flag():
            # One-element bool tensor initialised to False.
            return to_variable(np.array([0]).astype(np.bool_))

        self._found_inf = _false_flag()
        # One flag per gradient dtype group, written by check_finite_and_unscale.
        self._temp_found_inf_fp16 = _false_flag()
        self._temp_found_inf_bf16 = _false_flag()
        self._temp_found_inf_fp32 = _false_flag()
        self._scale = to_variable(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )
        self._cache_founf_inf = None
        self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def scale(self, var):
        """
        Return `var` multiplied by the current loss scaling factor.
        When this :class:`AmpScaler` is disabled, `var` is handed back untouched.

        Args:
            var (Tensor):  The Tensor to scale.
        Returns:
            The scaled Tensor or original Tensor.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                model = paddle.nn.Conv2D(3, 2, 3)
                optimizer = paddle.optimizer.SGDOptimizer(
                        learning_rate=0.01, parameter_list=model.parameters())
                scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
                data = paddle.to_tensor(data)
                with paddle.amp.amp_guard():
                    conv = model(data)
                    loss = paddle.mean(conv)
                    scaled = scaler.scale(loss)
                    scaled.backward()
                    scaler.minimize(optimizer, scaled)
        """
        # The type check runs even for a disabled scaler.
        check_type(var, "var", core.VarBase, 'AmpScaler.scale()')
        return var * self._scale if self._enable else var

    def minimize(self, optimizer, *args, **kwargs):
        """
        This function is similar to `Optimizer.minimize()`, which performs parameters updating.

        If the scaled gradients of parameters contain NAN or INF, the parameters updating is skipped.
        Otherwise, if the gradients have not been unscaled yet for this optimizer, they are unscaled first and then the parameters are updated.

        Finally, the loss scaling ratio is updated (only when dynamic loss scaling is used).

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forwarded to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `optimizer.minimize()`.

        Returns:
            When the scaler is disabled, whatever `optimizer.minimize()` returns.
            Otherwise a tuple `(optimize_ops, params_grads)` from
            `optimizer.minimize()`, or `(None, None)` when the update was skipped.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                model = paddle.nn.Conv2D(3, 2, 3)
                optimizer = paddle.optimizer.SGDOptimizer(
                        learning_rate=0.01, parameter_list=model.parameters())
                scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
                data = paddle.to_tensor(data)
                with paddle.amp.amp_guard():
                    conv = model(data)
                    loss = paddle.mean(conv)
                    scaled = scaler.scale(loss)
                    scaled.backward()
                    scaler.minimize(optimizer, scaled)
        """
        if not self._enable:
            return optimizer.minimize(*args, **kwargs)

        optimizer_state = self._optimizer_states[id(optimizer)]

        # Unscale the gradients unless that already happened for this optimizer.
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

        optimize_ops, params_grads = (None, None)

        if self._found_inf:
            # Non-finite gradients: skip the parameter update for this step.
            self._cache_founf_inf = True
        else:
            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
            self._cache_founf_inf = False

        if self._use_dynamic_loss_scaling:
            # Update the loss scale based on whether this step was skipped.
            self._update()

        # Start the next iteration with every optimizer back in the INIT state.
        self._optimizer_states = defaultdict(_refresh_optimizer_state)

        return optimize_ops, params_grads

    def _unscale(self, optimizer):
        """
        Unscale the gradients of parameters in place, i.e. multiply them by
        1/(loss scaling ratio), and record whether any of them is NAN or INF.

        If this instance of :class:`AmpScaler` is not enabled, nothing is done.

        Args:
            optimizer(Optimizer):  The optimizer whose parameters' gradients are unscaled.

        Returns:
            None. The result is stored in `self._found_inf` and the optimizer's
            tracked state becomes `OptimizerState.UNSCALED`.

        Raises:
            RuntimeError: If the gradients of this optimizer were already
                unscaled, or the optimizer was already stepped, since the
                optimizer states were last reset.
        """
        if not self._enable:
            return

        optimizer_state = self._optimizer_states[id(optimizer)]

        if optimizer_state["state"] is OptimizerState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["state"] is OptimizerState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # Split the available gradients into one list per dtype group.
        if getattr(optimizer, '_param_groups', None) and isinstance(
            optimizer._param_groups[0], dict
        ):
            param_grads_fp16 = []
            param_grads_bf16 = []
            param_grads_fp32 = []
            for group in optimizer._param_groups:
                for param in group['params']:
                    grad = param._grad_ivar()
                    if grad is None:
                        continue
                    if grad.dtype == core.VarDesc.VarType.FP16:
                        param_grads_fp16.append(grad)
                    elif grad.dtype == core.VarDesc.VarType.BF16:
                        param_grads_bf16.append(grad)
                    else:
                        param_grads_fp32.append(grad)
        elif in_dygraph_mode():
            # It is very time-consuming to call c++ functions in a loop on the python side.
            # We put this part of the code on the c++ side to improve the speed in eager mode.
            (
                param_grads_fp16,
                param_grads_bf16,
                param_grads_fp32,
            ) = core.eager.get_grads_lists(optimizer._parameter_list)
        else:
            # Keep the original code to support legacy mode.
            # Delete the else branch when the legacy mode exits.
            param_grads = [
                param._grad_ivar()
                for param in optimizer._parameter_list
                if param._grad_ivar() is not None
            ]
            param_grads_fp16 = [
                param
                for param in param_grads
                if param.dtype == core.VarDesc.VarType.FP16
            ]
            param_grads_bf16 = [
                param
                for param in param_grads
                if param.dtype == core.VarDesc.VarType.BF16
            ]
            # NOTE(review): unlike the param-group branch above, gradients that
            # are neither FP16, BF16 nor FP32 are dropped here - confirm intended.
            param_grads_fp32 = [
                param
                for param in param_grads
                if param.dtype == core.VarDesc.VarType.FP32
            ]

        # NPU builds pass an extra float-status input to the op.
        extra_inputs = ()
        if core.is_compiled_with_npu():
            float_status = _legacy_C_ops.alloc_float_status()
            _legacy_C_ops.clear_float_status(float_status, float_status)
            extra_inputs = (float_status,)

        # Run the op on every non-empty dtype group and remember which flags
        # were actually refreshed by this call.
        checked_flags = []
        for grads, found_inf_flag in (
            (param_grads_fp16, self._temp_found_inf_fp16),
            (param_grads_bf16, self._temp_found_inf_bf16),
            (param_grads_fp32, self._temp_found_inf_fp32),
        ):
            if len(grads):
                _legacy_C_ops.check_finite_and_unscale(
                    grads,
                    self._scale,
                    *extra_inputs,
                    grads,
                    found_inf_flag,
                )
                checked_flags.append(found_inf_flag)

        # Combine only the refreshed flags. A flag belonging to a dtype group
        # that was empty this time may still hold the value of an earlier call
        # (e.g. for a different optimizer) and must not cause a skipped step.
        found_inf = None
        for found_inf_flag in checked_flags:
            found_inf = found_inf_flag
            if found_inf_flag:
                break
        if found_inf is None:
            # Nothing was checked, so nothing can be non-finite.
            found_inf = to_variable(np.array([0]).astype(np.bool_))
        self._found_inf = found_inf

        optimizer_state["state"] = OptimizerState.UNSCALED

    def _update(self):
        """
        Updates the loss_scaling.
        """
        if not self._enable:
            return

        if self._cache_founf_inf:
            self._incr_count = 0
            self._decr_count = self._decr_count + 1
            if self._decr_count == self._decr_every_n_nan_or_inf:
                print(
                    'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format(
                        float(self._scale),
                        float(self._scale),
                        float(self._decr_ratio),
                    )
                )
                self._scale = self._scale * self._decr_ratio
                self._decr_count = 0
        else:
            self._decr_count = 0
            self._incr_count = self._incr_count + 1
            if self._incr_count == self._incr_every_n_steps:
                self._scale = self._scale * self._incr_ratio
                self._incr_count = 0

        return

    def is_enable(self):
        """
        Tell whether loss scaling is enabled for this scaler.

        Returns:
            bool: True if loss scaling is enabled, otherwise False.
        """
        return self._enable

    def is_use_dynamic_loss_scaling(self):
        """
        Whether to use dynamic loss scaling.

        Returns:
            bool: False if a fixed loss_scaling is used, True if the loss scaling is updated dynamically.
        """
        return self._use_dynamic_loss_scaling

    def get_init_loss_scaling(self):
        """
        Return the initial loss scaling factor.

        Returns:
            float:  the initial loss scaling factor.
        """
        return self._init_loss_scaling

    def set_init_loss_scaling(self, new_init_loss_scaling):
        """
        Set the initial loss scaling factor by `new_init_loss_scaling`.

        The current scale tensor is also re-created from the new value.

        Args:
            new_init_loss_scaling(float):  The new_init_loss_scaling used to update the initial loss scaling factor.
        """
        self._init_loss_scaling = new_init_loss_scaling
        self._scale = to_variable(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )

    def get_incr_ratio(self):
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  the multiplier to use when increasing the loss scaling.
        """
        return self._incr_ratio

    def set_incr_ratio(self, new_incr_ratio):
        """
        Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`; `new_incr_ratio` must be > 1.0.

        Args:
            new_incr_ratio(float):  The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.

        Raises:
            AssertionError: If `new_incr_ratio` is not greater than 1.0.
        """
        assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0."
        self._incr_ratio = new_incr_ratio

    def get_decr_ratio(self):
        """
        Get the less-than-one-multiplier to use when decreasing the loss scaling.

        Returns:
            float:  the less-than-one-multiplier to use when decreasing the loss scaling.
        """
        return self._decr_ratio

    def set_decr_ratio(self, new_decr_ratio):
        """
        Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_decr_ratio`; `new_decr_ratio` must be < 1.0.

        Args:
            new_decr_ratio(float):  The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.

        Raises:
            AssertionError: If `new_decr_ratio` is not less than 1.0.
        """
        assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0."
        self._decr_ratio = new_decr_ratio

    def get_incr_every_n_steps(self):
        """
        Return the number `n`: the loss scaling is increased every `n` consecutive steps with finite gradients.

        Returns:
            int:  the number `n` of consecutive steps with finite gradients after which the loss scaling is increased.
        """
        return self._incr_every_n_steps

    def set_incr_every_n_steps(self, new_incr_every_n_steps):
        """
        Set the number `n` by `new_incr_every_n_steps`: the loss scaling is increased every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The new number `n` of consecutive steps with finite gradients after which the loss scaling is increased.
        """
        self._incr_every_n_steps = new_incr_every_n_steps

    def get_decr_every_n_nan_or_inf(self):
        """
        Return the number `n`: the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Returns:
            int:  the number `n` of accumulated steps with nan or inf gradients after which the loss scaling is decreased.
        """
        return self._decr_every_n_nan_or_inf

    def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
        """
        Set the number `n` by `new_decr_every_n_nan_or_inf`: the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The new number `n` of accumulated steps with nan or inf gradients after which the loss scaling is decreased.
        """
        self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf

    def state_dict(self):
        """
        Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict.

        Reurns:
            A dict of scaler includes:
            scale (tensor): The loss scaling factor.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
        """
        return (
            {
                "scale": self._scale.numpy(),
                "incr_ratio": self._incr_ratio,
                "decr_ratio": self._decr_ratio,
                "incr_every_n_steps": self._incr_every_n_steps,
                "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
                "incr_count": self._incr_count,
                "decr_count": self._decr_count,
                "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling,
            }
            if self._enable
            else {}
        )

    def load_state_dict(self, state_dict):
        """
        Restore the scaler from a previously exported state. No-op when disabled.

        The stored scale becomes both the new initial loss scaling and the
        current scale tensor.

        Args:
           state_dict(dict): scaler state.  Should be an object returned from a call to `AmpScaler.state_dict()`.

        Raises:
            RuntimeError: If `state_dict` is empty.
        """
        if not self._enable:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The input state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        restored_scale = state_dict["scale"][0]
        self._init_loss_scaling = restored_scale
        self._scale = to_variable(np.array([restored_scale]).astype(np.float32))

        # Every remaining entry maps onto the private attribute of the same name.
        for key in (
            "incr_ratio",
            "decr_ratio",
            "incr_every_n_steps",
            "decr_every_n_nan_or_inf",
            "incr_count",
            "decr_count",
            "use_dynamic_loss_scaling",
        ):
            setattr(self, "_" + key, state_dict[key])


581 582
class GradScaler(AmpScaler):
    """
    GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
    It controls the scaling of loss, helps avoiding numerical overflow.
    The object of this class has nineteen methods `scale()`, `unscale_()`, `minimize()`, `step()`, `update()` and `get`/`set` api of parameters.

    `scale()` is used to multiply the loss by a scale ratio.
    `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
    `minimize()` is similar to `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling; it is equivalent to `step()` + `update()`.
    `step()` is similar to `optimizer.step()`, which performs parameters updating.
    `update()` is used to update the loss_scaling.


    Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
    dynamic graph mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Default is 2.0.
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
                        the loss scaling. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 2.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        A GradScaler object.

    Examples:

        .. code-block:: python

            import paddle

            model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
            optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            data = paddle.rand([10, 3, 32, 32])

            with paddle.amp.auto_cast():
                conv = model(data)
                loss = paddle.mean(conv)

            scaled = scaler.scale(loss)  # scale the loss
            scaled.backward()            # do backward
            scaler.minimize(optimizer, scaled)  # update parameters
            optimizer.clear_grad()
    """

633 634 635 636 637 638 639 640 641 642
    def __init__(
        self,
        enable=True,
        init_loss_scaling=2.0**15,
        incr_ratio=2.0,
        decr_ratio=0.5,
        incr_every_n_steps=1000,
        decr_every_n_nan_or_inf=2,
        use_dynamic_loss_scaling=True,
    ):
        """
        Forward every argument unchanged to `AmpScaler.__init__`.

        The only difference from the parent signature is the default of
        `decr_every_n_nan_or_inf`, which is 2 here.
        """
        super().__init__(
            enable=enable,
            init_loss_scaling=init_loss_scaling,
            incr_ratio=incr_ratio,
            decr_ratio=decr_ratio,
            incr_every_n_steps=incr_every_n_steps,
            decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        )
652 653 654

    def scale(self, var):
        """
        Multiplies a Tensor by the scale factor and returns scaled outputs.
        If this instance of :class:`GradScaler` is not enabled, output are returned unmodified.

        Args:
            var (Tensor):  The tensor to scale.
        Returns:
            The scaled tensor or original tensor.

        Examples:

            .. code-block:: python

                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])

                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
                optimizer.clear_grad()
        """
        # Pure delegation; overridden only to carry GradScaler-specific docs.
        return super().scale(var)
684 685 686

    def minimize(self, optimizer, *args, **kwargs):
        """
        This function is similar to `optimizer.minimize()`, which performs parameters updating.

        If the scaled gradients of parameters contain NAN or INF, the parameters updating is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.

        Finally, the loss scaling ratio is updated.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forwarded to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `optimizer.minimize()`.

        Returns:
            Whatever `AmpScaler.minimize()` returns for the same arguments.

        Examples:

            .. code-block:: python

                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])

                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
                optimizer.clear_grad()
        """
        # Pure delegation; overridden only to carry GradScaler-specific docs.
        return super().minimize(optimizer, *args, **kwargs)
720

721 722 723
    def step(self, optimizer):
        """
        This function is similar to `optimizer.step()`, which performs parameters updating.

        If the scaled gradients of parameters contain NAN or INF, the parameters updating is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Returns:
            When the scaler is disabled, whatever `optimizer.step()` returns;
            otherwise None.

        Raises:
            RuntimeError: If `step()` was already called for this optimizer
                since the optimizer states were last reset.

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.step(optimizer)       # update parameters
                scaler.update()              # update the loss scaling ratio
                optimizer.clear_grad()
        """
        if not self._enable:
            return optimizer.step()

        optimizer_state = self._optimizer_states[id(optimizer)]
        if optimizer_state["state"] is OptimizerState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        # Unscale the gradients unless that already happened for this optimizer.
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

        if self._found_inf:
            # Non-finite gradients: skip the parameter update for this step.
            self._cache_founf_inf = True
        else:
            optimizer.step()
            self._cache_founf_inf = False

        optimizer_state["state"] = OptimizerState.STEPPED

        # With a fixed loss scale the optimizer states are reset right here.
        if not self._use_dynamic_loss_scaling:
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def update(self):
        """
        Updates the loss_scaling.

        Does nothing when this scaler is disabled or when dynamic loss scaling is
        not used. Otherwise the loss scaling is updated and the per-optimizer
        state table is replaced by a fresh one, so every optimizer starts again
        from the ``INIT`` state.

        Returns:
            None.

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)     # scale the loss
                scaled.backward()               # do backward
                scaler.step(optimizer)          # update parameters
                scaler.update()                 # update the loss scaling ratio
                optimizer.clear_grad()
        """
        if not self._enable:
            return
        if self._use_dynamic_loss_scaling:
            self._update()
            # Forget which optimizers were unscaled/stepped in the finished iteration.
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def unscale_(self, optimizer):
        """
        Unscale the gradients of parameters, i.e. multiply the gradients of parameters by 1/(loss scaling ratio).
        If this instance of :class:`GradScaler` is not enabled, the gradients are left unmodified.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Returns:
            The result of the base-class `_unscale(optimizer)` call.

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.unscale_(optimizer)   # unscale the gradients
                scaler.step(optimizer)
                scaler.update()
                optimizer.clear_grad()
        """
        # Public alias of the base-class private implementation.
        return super()._unscale(optimizer)
839

840 841 842 843 844 845
    def is_enable(self):
        """
        Tell whether loss scaling is enabled.

        Returns:
            bool: True if loss scaling is enabled, otherwise False.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                enable = scaler.is_enable()
                print(enable) # True
        """
        return super().is_enable()
863 864 865 866 867 868 869

    def is_use_dynamic_loss_scaling(self):
        """
        Tell whether dynamic loss scaling is used.

        Returns:
            bool: False if a fixed loss scaling is used, True if the loss scaling is updated dynamically.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                use_dynamic_loss_scaling = scaler.is_use_dynamic_loss_scaling()
                print(use_dynamic_loss_scaling) # True
        """
        return super().is_use_dynamic_loss_scaling()
887 888 889 890 891 892 893

    def get_init_loss_scaling(self):
        """
        Return the initial loss scaling factor.

        Returns:
            float:  the initial loss scaling factor.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                init_loss_scaling = scaler.get_init_loss_scaling()
                print(init_loss_scaling) # 1024
        """
        return super().get_init_loss_scaling()
911 912 913 914 915 916

    def set_init_loss_scaling(self, new_init_loss_scaling):
        """
        Set the initial loss scaling factor to `new_init_loss_scaling`.

        Args:
            new_init_loss_scaling(float):  The new value of the initial loss scaling factor.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_init_loss_scaling()) # 1024
                new_init_loss_scaling = 1000
                scaler.set_init_loss_scaling(new_init_loss_scaling)
                print(scaler.get_init_loss_scaling()) # 1000
        """
        super().set_init_loss_scaling(new_init_loss_scaling)
937 938 939 940 941 942 943

    def get_incr_ratio(self):
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  the multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                incr_ratio = scaler.get_incr_ratio()
                print(incr_ratio) # 2.0
        """
        return super().get_incr_ratio()
961 962 963 964 965 966 967

    def set_incr_ratio(self, new_incr_ratio):
        """
        Set the multiplier to use when increasing the loss scaling to `new_incr_ratio`; `new_incr_ratio` should be > 1.0.

        Args:
            new_incr_ratio(float):  The new multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_incr_ratio()) # 2.0
                new_incr_ratio = 3.0
                scaler.set_incr_ratio(new_incr_ratio)
                print(scaler.get_incr_ratio()) # 3.0
        """
        super().set_incr_ratio(new_incr_ratio)
987 988 989 990 991 992 993

    def get_decr_ratio(self):
        """
        Return the less-than-one multiplier to use when decreasing the loss scaling.

        Returns:
            float:  the less-than-one multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                decr_ratio = scaler.get_decr_ratio()
                print(decr_ratio) # 0.5
        """
        return super().get_decr_ratio()
1011 1012 1013 1014 1015 1016 1017

    def set_decr_ratio(self, new_decr_ratio):
        """
        Set the less-than-one multiplier to use when decreasing the loss scaling to `new_decr_ratio`; `new_decr_ratio` should be < 1.0.

        Args:
            new_decr_ratio(float):  The new less-than-one multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_decr_ratio()) # 0.5
                new_decr_ratio = 0.1
                scaler.set_decr_ratio(new_decr_ratio)
                print(scaler.get_decr_ratio()) # 0.1
        """
        super().set_decr_ratio(new_decr_ratio)
1037 1038 1039 1040 1041 1042 1043

    def get_incr_every_n_steps(self):
        """
        Return the number `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Returns:
            int:  the number `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                incr_every_n_steps = scaler.get_incr_every_n_steps()
                print(incr_every_n_steps) # 1000
        """
        return super().get_incr_every_n_steps()
1061 1062 1063 1064 1065 1066 1067

    def set_incr_every_n_steps(self, new_incr_every_n_steps):
        """
        Set the number `n` to `new_incr_every_n_steps`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The new value of `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_incr_every_n_steps()) # 1000
                new_incr_every_n_steps = 2000
                scaler.set_incr_every_n_steps(new_incr_every_n_steps)
                print(scaler.get_incr_every_n_steps()) # 2000
        """
        super().set_incr_every_n_steps(new_incr_every_n_steps)
1087 1088 1089 1090 1091 1092 1093

    def get_decr_every_n_nan_or_inf(self):
        """
        Return the number `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Returns:
            int:  the number `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                decr_every_n_nan_or_inf = scaler.get_decr_every_n_nan_or_inf()
                print(decr_every_n_nan_or_inf) # 2
        """
        return super().get_decr_every_n_nan_or_inf()
1111 1112 1113 1114 1115 1116 1117

    def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
        """
        Set the number `n` to `new_decr_every_n_nan_or_inf`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The new value of `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_decr_every_n_nan_or_inf()) # 2
                new_decr_every_n_nan_or_inf = 3
                scaler.set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)
                print(scaler.get_decr_every_n_nan_or_inf()) # 3
        """
        super().set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)
1137 1138 1139 1140 1141 1142 1143

    def state_dict(self):
        """
        Returns the state of the scaler as a `dict`. If this instance is not enabled, returns an empty dict.

        Returns:
            A dict of the scaler state, which includes:

            - scale (tensor): The loss scaling factor.
            - incr_ratio(float): The multiplier to use when increasing the loss scaling.
            - decr_ratio(float): The less-than-one multiplier to use when decreasing the loss scaling.
            - incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            - decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            - incr_count(int): The number of recent consecutive unskipped steps.
            - decr_count(int): The number of recent consecutive skipped steps.
            - use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.

        Examples:

            .. code-block:: python

                # required: gpu,xpu
                import paddle

                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                scaler_state = scaler.state_dict()
        """
        return super().state_dict()
1171 1172 1173 1174

    def load_state_dict(self, state_dict):
        """
        Loads the scaler state.

        Args:
            state_dict(dict): scaler state.  Should be an object returned from a call to `GradScaler.state_dict()`.

        Examples:

            .. code-block:: python

                # required: gpu,xpu
                import paddle

                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                scaler_state = scaler.state_dict()
                scaler.load_state_dict(scaler_state)
        """
        super().load_state_dict(state_dict)