loss_scaler.py 22.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
17 18 19 20 21
from paddle.fluid.framework import (
    _varbase_creator,
    _dygraph_tracer,
    dygraph_only,
)
22 23 24 25
from paddle.fluid.data_feeder import check_type
from ...wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import warnings
import numpy as np
26
from paddle import _C_ops, _legacy_C_ops
27 28
from collections import defaultdict
from enum import Enum
29
from paddle.fluid import in_dygraph_mode
30

# Names re-exported by ``from <this module> import *``.
__all__ = [
    'AmpScaler',
    'OptimizerState',
]


class OptimizerState(Enum):
    """Per-optimizer progress marker used by AmpScaler within one update cycle."""

    # Gradients have not been unscaled yet in this cycle.
    INIT = 0
    # Gradients have been unscaled (by `_unscale`) in this cycle.
    UNSCALED = 1
    # The optimizer has already stepped in this cycle.
    STEPPED = 2


def _refresh_optimizer_state():
    """Build a brand-new bookkeeping dict whose ``"state"`` is ``OptimizerState.INIT``.

    Used as the ``defaultdict`` factory, so every call must return a fresh dict.
    """
    fresh_state = dict(state=OptimizerState.INIT)
    return fresh_state


class AmpScaler:
    """
    :api_attr: imperative

    AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
    mode. It controls the scaling of loss, helps avoiding numerical overflow.

    The object of this class provides `scale()`, `minimize()`, `get`/`set` apis for
    its parameters, and `state_dict()`/`load_state_dict()` for checkpointing.

    `scale()` is used to multiply the loss by a scale ratio.
    `minimize()` is similar as `optimizer.minimize()`. It first unscales the gradients
    of parameters (multiplies them by 1/(scale ratio)) unless they were already
    unscaled for this optimizer. It skips the parameter update if any gradient is
    NaN or Inf, and otherwise performs the update. When dynamic loss scaling is
    used, it then updates the loss scaling.

    Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
    imperative mode.

    If the current place is not a CUDAPlace, XPUPlace, MLUPlace, NPUPlace or
    CustomPlace, a warning is emitted and the scaler is disabled. A disabled
    scaler leaves the loss and gradients unmodified.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Must be > 1.0. Default is 2.0.
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
                        the loss scaling. Must be < 1.0. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 1.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        An AmpScaler object.
    Raises:
        ValueError: If there is no current dygraph tracer (not in imperative mode).

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with fluid.dygraph.guard():
            model = fluid.dygraph.Conv2D(3, 2, 3)
            optimizer = fluid.optimizer.SGDOptimizer(
                    learning_rate=0.01, parameter_list=model.parameters())
            scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
            data = fluid.dygraph.to_variable(data)
            with fluid.dygraph.amp_guard():
                conv = model(data)
                loss = fluid.layers.reduce_mean(conv)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
    """

    @dygraph_only
    def __init__(
        self,
        enable=True,
        init_loss_scaling=2.0**15,
        incr_ratio=2.0,
        decr_ratio=0.5,
        incr_every_n_steps=1000,
        decr_every_n_nan_or_inf=1,
        use_dynamic_loss_scaling=True,
    ):
        tracer = _dygraph_tracer()
        if not tracer:
            raise ValueError(
                "current_tracer is None, maybe it is not in imperative mode."
            )

        place = tracer._expected_place
        if enable and not (
            place.is_gpu_place()
            or place.is_xpu_place()
            or place.is_mlu_place()
            or place.is_npu_place()
            or place.is_custom_place()
        ):
            warnings.warn(
                'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
                % place
            )
            enable = False

        self._enable = enable

        # Defaults for a disabled scaler. They let `is_use_dynamic_loss_scaling()`
        # and `get_init_loss_scaling()` answer instead of raising AttributeError.
        # These values are overwritten below when the scaler is enabled.
        self._use_dynamic_loss_scaling = False
        self._init_loss_scaling = 1.0
        self._scale = None

        if self._enable:
            # NOTE: kept as `assert` (AssertionError) for backward compatibility.
            assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
            assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."

            self._init_loss_scaling = init_loss_scaling
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_count = 0
            self._decr_count = 0
            self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

            # Single-element bool tensors. The `_temp_*` ones are outputs of
            # check_finite_and_unscale, one per gradient dtype group.
            self._found_inf = to_variable(np.array([0]).astype(np.bool_))
            self._temp_found_inf_fp16 = to_variable(
                np.array([0]).astype(np.bool_)
            )
            self._temp_found_inf_bf16 = to_variable(
                np.array([0]).astype(np.bool_)
            )
            self._temp_found_inf_fp32 = to_variable(
                np.array([0]).astype(np.bool_)
            )
            self._scale = to_variable(
                np.array([self._init_loss_scaling]).astype(np.float32)
            )
            # Whether the last minimize() saw inf/nan; read by _update().
            self._cache_founf_inf = None
            # id(optimizer) -> {"state": OptimizerState}; reset after each minimize().
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def scale(self, var):
        """
        Multiplies a variable(Tensor) by the scale factor and returns scaled outputs.
        If this instance of :class:`AmpScaler` is not enabled, output are returned unmodified.

        Args:
            var (Variable):  The variable to scale.
        Returns:
            The scaled variable or original variable.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                with fluid.dygraph.guard():
                    model = fluid.dygraph.Conv2D(3, 2, 3)
                    optimizer = fluid.optimizer.SGDOptimizer(
                            learning_rate=0.01, parameter_list=model.parameters())
                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
                    data = fluid.dygraph.to_variable(data)
                    with fluid.dygraph.amp_guard():
                        conv = model(data)
                        loss = fluid.layers.reduce_mean(conv)
                        scaled = scaler.scale(loss)
                        scaled.backward()
                        scaler.minimize(optimizer, scaled)
        """
        check_type(var, "var", core.VarBase, 'AmpScaler.scale()')

        if not self._enable:
            return var

        return var * self._scale

    def minimize(self, optimizer, *args, **kwargs):
        """
        This function is similar as `Optimizer.minimize()`, which performs parameters updating.

        If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
        Otherwise, if the gradients have not been unscaled yet, it first unscales the scaled gradients of parameters, then updates the parameters.

        Finally, if dynamic loss scaling is used, the loss scaling ratio is updated.
        Calling this also resets the unscale/step bookkeeping of every optimizer.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forward to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`.
        Returns:
            When disabled, whatever `optimizer.minimize()` returns. Otherwise the
            `(optimize_ops, params_grads)` returned by `optimizer.minimize()`, or
            `(None, None)` when the update was skipped.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                with fluid.dygraph.guard():
                    model = fluid.dygraph.Conv2D(3, 2, 3)
                    optimizer = fluid.optimizer.SGDOptimizer(
                            learning_rate=0.01, parameter_list=model.parameters())
                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
                    data = fluid.dygraph.to_variable(data)
                    with fluid.dygraph.amp_guard():
                        conv = model(data)
                        loss = fluid.layers.reduce_mean(conv)
                        scaled = scaler.scale(loss)
                        scaled.backward()
                        scaler.minimize(optimizer, scaled)
        """
        if not self._enable:
            return optimizer.minimize(*args, **kwargs)

        optimizer_state = self._optimizer_states[id(optimizer)]

        # Unscale the gradients unless that was already done for this optimizer.
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

        optimize_ops, params_grads = (None, None)

        if self._found_inf:
            # Skip the update: applying inf/nan gradients would corrupt parameters.
            self._cache_founf_inf = True
        else:
            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
            self._cache_founf_inf = False

        if self._use_dynamic_loss_scaling:
            # update the scale
            self._update()

        # Start a fresh unscale/step cycle for every optimizer.
        self._optimizer_states = defaultdict(_refresh_optimizer_state)

        return optimize_ops, params_grads

    def _unscale(self, optimizer):
        """
        Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio).
        If this instance of :class:`AmpScaler` is not enabled, the gradients are left unmodified.

        The gradients are handled per dtype group (fp16, bf16, and the rest). Each
        group goes through `check_finite_and_unscale`, which receives the same
        gradient list as input and output. Afterwards `self._found_inf` tells
        whether any gradient checked in this call was NaN or Inf.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
        Returns:
            None.
        Raises:
            RuntimeError: If the gradients of `optimizer` were already unscaled
                since the last update, or if the optimizer has already stepped.
        """
        if not self._enable:
            return

        optimizer_state = self._optimizer_states[id(optimizer)]

        if optimizer_state["state"] is OptimizerState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["state"] is OptimizerState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        (
            param_grads_fp16,
            param_grads_bf16,
            param_grads_fp32,
        ) = self._split_grads_by_dtype(optimizer)

        if core.is_compiled_with_npu():
            # On NPU the op takes an extra float-status input. It is allocated and
            # cleared once here and shared by all dtype groups.
            float_status = _legacy_C_ops.alloc_float_status()
            _legacy_C_ops.clear_float_status(float_status, float_status)
            extra_inputs = (float_status,)
        else:
            extra_inputs = ()

        checked_flags = []
        for grads, inf_flag in (
            (param_grads_fp16, self._temp_found_inf_fp16),
            (param_grads_bf16, self._temp_found_inf_bf16),
            (param_grads_fp32, self._temp_found_inf_fp32),
        ):
            if not len(grads):
                continue
            _legacy_C_ops.check_finite_and_unscale(
                grads, self._scale, *extra_inputs, grads, inf_flag
            )
            checked_flags.append(inf_flag)

        # Combine with `or` semantics, but only over flags written in this call.
        # A group with no gradients this step still holds its flag from an
        # earlier step, and that stale value must not skip this step's update.
        found_inf = None
        for flag in checked_flags:
            found_inf = flag
            if flag:
                break
        if found_inf is None:
            # No gradient was checked at all, so nothing can be non-finite.
            found_inf = to_variable(np.array([0]).astype(np.bool_))
        self._found_inf = found_inf

        optimizer_state["state"] = OptimizerState.UNSCALED

    @staticmethod
    def _split_grads_by_dtype(optimizer):
        """
        Collect the non-None gradients of `optimizer`'s parameters, grouped by dtype.

        Args:
            optimizer(Optimizer): The optimizer whose parameters' gradients are collected.
        Returns:
            tuple: `(fp16_grads, bf16_grads, fp32_grads)`.
            - Parameter-group optimizers: every gradient that is neither FP16 nor
              BF16 goes into the last list.
            - Eager mode: the grouping is delegated to `core.eager.get_grads_lists`.
            - Legacy mode: the last list holds only FP32 gradients.
        """
        if getattr(optimizer, '_param_groups', None) and isinstance(
            optimizer._param_groups[0], dict
        ):
            grads_fp16, grads_bf16, grads_fp32 = [], [], []
            for group in optimizer._param_groups:
                for param in group['params']:
                    grad = param._grad_ivar()
                    if grad is None:
                        continue
                    if grad.dtype == core.VarDesc.VarType.FP16:
                        grads_fp16.append(grad)
                    elif grad.dtype == core.VarDesc.VarType.BF16:
                        grads_bf16.append(grad)
                    else:
                        grads_fp32.append(grad)
            return grads_fp16, grads_bf16, grads_fp32

        if in_dygraph_mode():
            # It is very time-consuming to call c++ functions in a loop on the python side.
            # We put this part of the code on the c++ side to improve the speed in eager mode.
            return core.eager.get_grads_lists(optimizer._parameter_list)

        # Keep the original code to support legacy mode.
        # Delete this branch when the legacy mode exits.
        param_grads = [
            param._grad_ivar()
            for param in optimizer._parameter_list
            if param._grad_ivar() is not None
        ]
        grads_fp16 = [
            grad
            for grad in param_grads
            if grad.dtype == core.VarDesc.VarType.FP16
        ]
        grads_bf16 = [
            grad
            for grad in param_grads
            if grad.dtype == core.VarDesc.VarType.BF16
        ]
        grads_fp32 = [
            grad
            for grad in param_grads
            if grad.dtype == core.VarDesc.VarType.FP32
        ]
        return grads_fp16, grads_bf16, grads_fp32

    def _update(self):
        """
        Updates the loss_scaling.

        Suppose the last `minimize()` found NaN/Inf gradients. Then the increase
        counter is reset and the decrease counter advances. Once that counter
        reaches `decr_every_n_nan_or_inf`, a message is printed, the scale is
        multiplied by `decr_ratio`, and the counter is reset.

        Otherwise the decrease counter is reset and the increase counter advances.
        Once it reaches `incr_every_n_steps`, the scale is multiplied by
        `incr_ratio` and the counter is reset.
        """
        if not self._enable:
            return

        # `>=` rather than `==`: if a threshold is lowered below the current
        # count (via its setter), the update still fires instead of never
        # triggering again.
        if self._cache_founf_inf:
            self._incr_count = 0
            self._decr_count += 1
            if self._decr_count >= self._decr_every_n_nan_or_inf:
                print(
                    'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format(
                        float(self._scale),
                        float(self._scale),
                        float(self._decr_ratio),
                    )
                )
                self._scale = self._scale * self._decr_ratio
                self._decr_count = 0
        else:
            self._decr_count = 0
            self._incr_count += 1
            if self._incr_count >= self._incr_every_n_steps:
                self._scale = self._scale * self._incr_ratio
                self._incr_count = 0

    def is_enable(self):
        """
        Enable loss scaling or not.

        Returns:
            bool: enable loss scaling return True else return False.
        """
        return self._enable

    def is_use_dynamic_loss_scaling(self):
        """
        Whether to use dynamic loss scaling.

        Returns:
            bool: False if a fixed loss scaling is used (always False for a disabled
            scaler), True if the loss scaling is updated dynamically.
        """
        return self._use_dynamic_loss_scaling

    def get_init_loss_scaling(self):
        """
        Return the initial loss scaling factor.

        Returns:
            float: the initial loss scaling factor (1.0 for a disabled scaler unless
            it was set explicitly).
        """
        return self._init_loss_scaling

    def set_init_loss_scaling(self, new_init_loss_scaling):
        """
        Set the initial loss scaling factor by `new_init_loss_scaling`.
        The current loss scaling factor is reset to this value as well.

        Args:
            new_init_loss_scaling(float|int):  The new_init_loss_scaling used to update initial loss scaling factor.
        """
        self._init_loss_scaling = new_init_loss_scaling
        self._scale = to_variable(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )

    def get_incr_ratio(self):
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  the multiplier to use when increasing the loss scaling.
        """
        return self._incr_ratio

    def set_incr_ratio(self, new_incr_ratio):
        """
        Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`, `new_incr_ratio` should > 1.0.

        Args:
            new_incr_ratio(float):  The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.
        """
        assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0."
        self._incr_ratio = new_incr_ratio

    def get_decr_ratio(self):
        """
        Get the less-than-one-multiplier to use when decreasing the loss scaling.

        Returns:
            float:  the less-than-one-multiplier to use when decreasing the loss scaling.
        """
        return self._decr_ratio

    def set_decr_ratio(self, new_decr_ratio):
        """
        Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_decr_ratio`, `new_decr_ratio` should < 1.0.

        Args:
            new_decr_ratio(float):  The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.
        """
        assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0."
        self._decr_ratio = new_decr_ratio

    def get_incr_every_n_steps(self):
        """
        Return the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.

        Returns:
            int:  the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
        """
        return self._incr_every_n_steps

    def set_incr_every_n_steps(self, new_incr_every_n_steps):
        """
        Set the num `n` by `new_incr_every_n_steps`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The new_incr_every_n_steps used to update the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
        """
        self._incr_every_n_steps = new_incr_every_n_steps

    def get_decr_every_n_nan_or_inf(self):
        """
        Return the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.

        Returns:
            int:  the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
        """
        return self._decr_every_n_nan_or_inf

    def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
        """
        Set the num `n` by `new_decr_every_n_nan_or_inf`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
        """
        self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf

    def state_dict(self):
        """
        Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict.

        Returns:
            A dict of scaler includes:
            scale (tensor): The loss scaling factor.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
        """
        return (
            {
                "scale": self._scale.numpy(),
                "incr_ratio": self._incr_ratio,
                "decr_ratio": self._decr_ratio,
                "incr_every_n_steps": self._incr_every_n_steps,
                "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
                "incr_count": self._incr_count,
                "decr_count": self._decr_count,
                "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling,
            }
            if self._enable
            else {}
        )

    def load_state_dict(self, state_dict):
        """
        Loads the scaler state. Does nothing if this instance is not enabled.

        Args:
           state_dict(dict): scaler state.  Should be an object returned from a call to `AmpScaler.state_dict()`.
        Raises:
            RuntimeError: If `state_dict` is empty.
        """
        if not self._enable:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The input state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        # The saved scale becomes both the initial and the current scale.
        self._init_loss_scaling = state_dict["scale"][0]
        self._scale = to_variable(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )
        self._incr_ratio = state_dict["incr_ratio"]
        self._decr_ratio = state_dict["decr_ratio"]
        self._incr_every_n_steps = state_dict["incr_every_n_steps"]
        self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"]
        self._incr_count = state_dict["incr_count"]
        self._decr_count = state_dict["decr_count"]
        self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"]