grad_scaler.py 26.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.dygraph.amp import AmpScaler
16 17
from paddle.fluid.dygraph.amp import OptimizerState
from collections import defaultdict
18

19
__all__ = []
20 21


22 23 24 25
def _refresh_optimizer_state():
    """Build a fresh per-optimizer record whose "state" is ``OptimizerState.INIT``.

    ``GradScaler`` passes this function to ``defaultdict`` so that looking up
    an optimizer that has no record yet yields a new record in the INIT state.
    """
    fresh_record = dict(state=OptimizerState.INIT)
    return fresh_record


26 27
class GradScaler(AmpScaler):
    """
28
    GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
29
    It controls the scaling of the loss, which helps to avoid numerical overflow.
30
    An object of this class has nineteen methods: `scale()`, `unscale_()`, `minimize()`, `step()`, `update()`, and the `get`/`set` APIs for its parameters.
31 32

    `scale()` is used to multiply the loss by a scale ratio.
33 34 35 36 37
    `unscale_()` is used to unscale the gradients of parameters; it multiplies the gradients of parameters by 1/(scale ratio).
    `minimize()` is similar to `optimizer.minimize()`: it performs parameter updating and also updates the loss_scaling, so it is equal to `step()` + `update()`.
    `step()` is similar to `optimizer.step()`, which performs parameter updating.
    `update()` is used to update the loss_scaling.

38

39
    Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
40 41 42 43 44
    dynamic graph mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
45
        incr_ratio(float, optional): The multiplier to use when increasing the loss
46
                        scaling. Default is 2.0.
47
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
48
                        the loss scaling. Default is 0.5.
49
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
50
                                steps with finite gradients. Default is 1000.
51
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
52 53 54
                                    accumulated steps with nan or inf gradients. Default is 2.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, a fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
55
        A GradScaler object.
56 57 58

    Examples:

59
        .. code-block:: python
60

61
            import paddle
62

63 64 65 66
            model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
            optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            data = paddle.rand([10, 3, 32, 32])
L
Leo Chen 已提交
67

68 69 70
            with paddle.amp.auto_cast():
                conv = model(data)
                loss = paddle.mean(conv)
71 72

            scaled = scaler.scale(loss)  # scale the loss
L
Leo Chen 已提交
73
            scaled.backward()            # do backward
74
            scaler.minimize(optimizer, scaled)  # update parameters
75
            optimizer.clear_grad()
76 77
    """

78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
    def __init__(
        self,
        enable=True,
        init_loss_scaling=2.0**15,
        incr_ratio=2.0,
        decr_ratio=0.5,
        incr_every_n_steps=1000,
        decr_every_n_nan_or_inf=2,
        use_dynamic_loss_scaling=True,
    ):
        """
        Initialize the scaler by handing every option to ``AmpScaler.__init__``.

        The options are forwarded positionally, in the same order as they
        appear in this signature.
        """
        # Order matters: the base initializer receives these positionally.
        base_options = (
            enable,
            init_loss_scaling,
            incr_ratio,
            decr_ratio,
            incr_every_n_steps,
            decr_every_n_nan_or_inf,
            use_dynamic_loss_scaling,
        )
        super().__init__(*base_options)
97 98 99

    def scale(self, var):
        """
        Multiply a Tensor by the scale factor and return the scaled output.
        If this instance of :class:`GradScaler` is not enabled, the output is returned unmodified.

        Args:
            var (Tensor):  The tensor to scale.

        Returns:
            The scaled tensor or the original tensor.

        Examples:

            .. code-block:: python

                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])

                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
                optimizer.clear_grad()
        """
        # Pure delegation: the scaling itself lives in the base class.
        scaled_var = super().scale(var)
        return scaled_var

    def minimize(self, optimizer, *args, **kwargs):
        """
        This function is similar to `optimizer.minimize()`, which performs parameter updating.

        If the scaled gradients of the parameters contain NAN or INF, the parameter update is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of the parameters, then updates the parameters.

        Finally, the loss scaling ratio is updated.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forwarded to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `optimizer.minimize()`.

        Examples:

            .. code-block:: python

                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])

                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
                optimizer.clear_grad()
        """
        # Pure delegation: the base class implements the whole procedure.
        outcome = super().minimize(optimizer, *args, **kwargs)
        return outcome
165

166 167 168
    def step(self, optimizer):
        """
        This function is similar to `optimizer.step()`, which performs parameter updating.

        If the scaled gradients of the parameters contain NAN or INF, the parameter update is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of the parameters, then updates the parameters.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Returns:
            The result of `optimizer.step()` when this scaler is not enabled; otherwise None.

        Raises:
            RuntimeError: If `step()` was already called for this optimizer since the last `update()`.

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.step(optimizer)       # update parameters
                scaler.update()              # update the loss scaling ratio
                optimizer.clear_grad()
        """
        # With scaling disabled this is a plain optimizer step.
        if not self._enable:
            return optimizer.step()

        # Records are keyed by the optimizer's identity.
        record = self._optimizer_states[id(optimizer)]
        phase = record["state"]

        if phase is OptimizerState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        # Gradients are still scaled if nothing has happened since INIT.
        if phase is OptimizerState.INIT:
            self._unscale(optimizer)

        # Skip the parameter update when an inf/nan gradient was found.
        # The flag is stored only after optimizer.step() returns, so an
        # exception raised by the step leaves the cached value untouched.
        # NOTE(review): the attribute name `_cache_founf_inf` looks misspelled
        # but is presumably read elsewhere (e.g. by AmpScaler); confirm before
        # renaming.
        overflow_seen = bool(self._found_inf)
        if not overflow_seen:
            optimizer.step()
        self._cache_founf_inf = overflow_seen

        record["state"] = OptimizerState.STEPPED

        # With a fixed loss scale all per-optimizer records are reset right
        # here; update() only resets them when dynamic scaling is on.
        if not self._use_dynamic_loss_scaling:
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def update(self):
        """
        Update the loss_scaling.

        Nothing happens unless this scaler is enabled and uses dynamic loss
        scaling. In that case the loss scaling is updated and all
        per-optimizer states are reset.

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)     # scale the loss
                scaled.backward()               # do backward
                scaler.step(optimizer)          # update parameters
                scaler.update()                 # update the loss scaling ratio
                optimizer.clear_grad()
        """
        # `_enable` is checked first; the dynamic-scaling flag is consulted
        # only when the scaler is enabled.
        if self._enable and self._use_dynamic_loss_scaling:
            self._update()
            # Start the next iteration with every optimizer back in INIT.
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def unscale_(self, optimizer):
        """
        Unscale the gradients of parameters by multiplying them by 1/(loss scaling ratio).
        If this instance of :class:`GradScaler` is not enabled, the output is returned unmodified.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Returns:
            The unscaled parameters or the original parameters (this is the
            value returned by the base class `_unscale()`).

        Examples:

            .. code-block:: python

                # required: gpu
                import paddle

                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                data = paddle.rand([10, 3, 32, 32])
                with paddle.amp.auto_cast():
                    conv = model(data)
                    loss = paddle.mean(conv)
                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()            # do backward
                scaler.unscale_(optimizer)    # unscale the parameter
                scaler.step(optimizer)
                scaler.update()
                optimizer.clear_grad()
        """
        # Public entry point for the base class's private `_unscale`.
        unscale_result = super()._unscale(optimizer)
        return unscale_result
284

285 286 287 288 289 290
    def is_enable(self):
        """
        Tell whether loss scaling is enabled.

        Returns:
            bool: True if loss scaling is enabled, otherwise False.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                enable = scaler.is_enable()
                print(enable) # True
        """
        enabled = super().is_enable()
        return enabled

    def is_use_dynamic_loss_scaling(self):
        """
        Tell whether dynamic loss scaling is used.

        Returns:
            bool: False if a fixed loss_scaling is used, True if the loss scaling is updated dynamically.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                use_dynamic_loss_scaling = scaler.is_use_dynamic_loss_scaling()
                print(use_dynamic_loss_scaling) # True
        """
        dynamic = super().is_use_dynamic_loss_scaling()
        return dynamic

    def get_init_loss_scaling(self):
        """
        Return the initial loss scaling factor.

        Returns:
            float:  The initial loss scaling factor.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                init_loss_scaling = scaler.get_init_loss_scaling()
                print(init_loss_scaling) # 1024
        """
        initial_factor = super().get_init_loss_scaling()
        return initial_factor

    def set_init_loss_scaling(self, new_init_loss_scaling):
        """
        Replace the initial loss scaling factor with `new_init_loss_scaling`.

        Args:
            new_init_loss_scaling(float):  The value that becomes the new initial loss scaling factor.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_init_loss_scaling()) # 1024
                new_init_loss_scaling = 1000
                scaler.set_init_loss_scaling(new_init_loss_scaling)
                print(scaler.get_init_loss_scaling()) # 1000
        """
        # The base class stores the value; this method itself returns None.
        base = super()
        base.set_init_loss_scaling(new_init_loss_scaling)

    def get_incr_ratio(self):
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  The multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                incr_ratio = scaler.get_incr_ratio()
                print(incr_ratio) # 2.0
        """
        growth_multiplier = super().get_incr_ratio()
        return growth_multiplier

    def set_incr_ratio(self, new_incr_ratio):
        """
        Replace the multiplier used when increasing the loss scaling with `new_incr_ratio`; `new_incr_ratio` should be > 1.0.

        Args:
            new_incr_ratio(float):  The value that becomes the new multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_incr_ratio()) # 2.0
                new_incr_ratio = 3.0
                scaler.set_incr_ratio(new_incr_ratio)
                print(scaler.get_incr_ratio()) # 3.0
        """
        # The base class stores the value; this method itself returns None.
        base = super()
        base.set_incr_ratio(new_incr_ratio)

    def get_decr_ratio(self):
        """
        Return the less-than-one multiplier to use when decreasing the loss scaling.

        Returns:
            float:  The less-than-one multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                decr_ratio = scaler.get_decr_ratio()
                print(decr_ratio) # 0.5
        """
        shrink_multiplier = super().get_decr_ratio()
        return shrink_multiplier

    def set_decr_ratio(self, new_decr_ratio):
        """
        Replace the less-than-one multiplier used when decreasing the loss scaling with `new_decr_ratio`; `new_decr_ratio` should be < 1.0.

        Args:
            new_decr_ratio(float):  The value that becomes the new less-than-one multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_decr_ratio()) # 0.5
                new_decr_ratio = 0.1
                scaler.set_decr_ratio(new_decr_ratio)
                print(scaler.get_decr_ratio()) # 0.1
        """
        # The base class stores the value; this method itself returns None.
        base = super()
        base.set_decr_ratio(new_decr_ratio)

    def get_incr_every_n_steps(self):
        """
        Return the number `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Returns:
            int:  The number `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                incr_every_n_steps = scaler.get_incr_every_n_steps()
                print(incr_every_n_steps) # 1000
        """
        growth_interval = super().get_incr_every_n_steps()
        return growth_interval

    def set_incr_every_n_steps(self, new_incr_every_n_steps):
        """
        Replace the number `n` with `new_incr_every_n_steps`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The value that becomes the new `n`, where the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_incr_every_n_steps()) # 1000
                new_incr_every_n_steps = 2000
                scaler.set_incr_every_n_steps(new_incr_every_n_steps)
                print(scaler.get_incr_every_n_steps()) # 2000
        """
        # The base class stores the value; this method itself returns None.
        base = super()
        base.set_incr_every_n_steps(new_incr_every_n_steps)

    def get_decr_every_n_nan_or_inf(self):
        """
        Return the number `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Returns:
            int:  The number `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                decr_every_n_nan_or_inf = scaler.get_decr_every_n_nan_or_inf()
                print(decr_every_n_nan_or_inf) # 2
        """
        backoff_interval = super().get_decr_every_n_nan_or_inf()
        return backoff_interval

    def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
        """
        Replace the number `n` with `new_decr_every_n_nan_or_inf`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The value that becomes the new `n`, where the loss scaling is decreased every `n` accumulated steps with nan or inf gradients.

        Examples:
            .. code-block:: python

                # required: gpu,xpu
                import paddle
                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                print(scaler.get_decr_every_n_nan_or_inf()) # 2
                new_decr_every_n_nan_or_inf = 3
                scaler.set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)
                print(scaler.get_decr_every_n_nan_or_inf()) # 3
        """
        # The base class stores the value; this method itself returns None.
        base = super()
        base.set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)
584 585 586 587 588 589 590

    def state_dict(self):
        """
        Return the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.

        Returns:
            A dict of scaler state that includes:
            scale (tensor): The loss scaling factor.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, a fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.

        Examples:

            .. code-block:: python

                # required: gpu,xpu
                import paddle

                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                scaler_state = scaler.state_dict()
        """
        # The base class assembles the dict.
        snapshot = super().state_dict()
        return snapshot

    def load_state_dict(self, state_dict):
        """
        Load the scaler state.

        Args:
           state_dict(dict): The scaler state. It should be an object returned from a call to `GradScaler.state_dict()`.

        Examples:

            .. code-block:: python

                # required: gpu,xpu
                import paddle

                scaler = paddle.amp.GradScaler(enable=True,
                                               init_loss_scaling=1024,
                                               incr_ratio=2.0,
                                               decr_ratio=0.5,
                                               incr_every_n_steps=1000,
                                               decr_every_n_nan_or_inf=2,
                                               use_dynamic_loss_scaling=True)
                scaler_state = scaler.state_dict()
                scaler.load_state_dict(scaler_state)
        """
        # The base class restores the state; this method itself returns None.
        base = super()
        base.load_state_dict(state_dict)