# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import warnings

import paddle
from paddle.fluid import (
    core,
    default_main_program,
    default_startup_program,
    program_guard,
    unique_name,
)

from .amp_nn import check_finite_and_unscale, update_loss_scaling
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
    cast_model_to_fp16,
    cast_parameters_to_fp16,
    update_role_var_grad,
)
from .function_overload import FunctionType, overload


def _set_multi_precision(optimizer, multi_precision):
    if not isinstance(
        optimizer,
        (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
    ):
        raise RuntimeError(
            "Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
                type(optimizer)
            )
        )

    if multi_precision and hasattr(optimizer, "_multi_precision"):
        optimizer._multi_precision = multi_precision


class OptimizerWithMixedPrecision:
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. The object
    of this class almost has the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables the MP training automatically, i.e., the creation
    and maintenance of master parameters, scaling of loss, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
             will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
             level, the model's parameters will be cast to float16/bfloat16 by using `decorator`, operators that have all float16/bfloat16 inputs
             will be converted to float16/bfloat16, and operators that have any float32 input will be converted to float32. For the OD level, operators in the
             default white list will compute in float16/bfloat16, and the others will compute in float32.
        dtype(str): Whether to use 'float16' or 'bfloat16'.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
        use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value is equal to `use_pure_fp16`.
        use_master_grad(bool): Whether to use fp32 master gradients during optimization. Default is False.
        use_promote(bool): Whether to promote to fp32 when an op has any float32 inputs. Default is False.
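
    Examples:
        This wrapper is normally obtained through `paddle.static.amp.decorate`
        rather than constructed directly. A minimal sketch (illustrative only;
        see the examples of `decorate` below for a complete program):

        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            # `mp_optimizer` is an OptimizerWithMixedPrecision instance.
            mp_optimizer = static.amp.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)
            ops, param_grads = mp_optimizer.minimize(loss)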
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        level,
        dtype,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
        use_master_grad=False,
        use_promote=False,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = None

        self._is_distributed = False
        self._use_master_grad = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if dtype == "bfloat16":
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_vartype = core.VarDesc.VarType.BF16
        else:
            self._amp_vartype = core.VarDesc.VarType.FP16

        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        if self._use_pure_fp16 and (dtype == "bfloat16" or dtype == "float16"):
            self._use_master_grad = use_master_grad
            self._optimizer._master_grad = use_master_grad
        self._amp_level = level
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        if self._use_dynamic_loss_scaling:
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._num_good_steps = None
            self._num_bad_steps = None
        self.use_promote = use_promote

    def _set_distributed(self, flag):
        # if distributed, all cards will communicate with each other,
        # overlap communication and computation by splitting the
        # check_finite_and_unscale op.
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.
        It's useful when you feed a custom loss into the executor.
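
        A minimal fetch sketch (illustrative only; assumes `mp_optimizer` was
        created by `paddle.static.amp.decorate`, `minimize()` has already been
        called, `exe` is a `paddle.static.Executor`, and `feed_dict` matches the
        data placeholders of the program):

        .. code-block:: python

            scaled_loss = mp_optimizer.get_scaled_loss()
            scaled_loss_v = exe.run(
                paddle.static.default_main_program(),
                feed=feed_dict,
                fetch_list=[scaled_loss],
            )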
        """
        return self._scaled_loss

    def _supports_check_nan_inf(self):
        return getattr(self._optimizer, "_supports_check_nan_inf", False)

    def _init_amp_var(self):
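        """Create the persistable AMP state vars: the loss scaling var, the
        good/bad step counters used by dynamic loss scaling, and a float32
        learning rate var when the learning rate is a plain float."""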
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = paddle.static.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
            callbacks (list|None): A list of callable objects to run when appending
                                   backward operator for one parameter.

        Returns:
            A list of (param, grad) pairs, where grad is the scaled gradient
            of param.
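
        A sketch of the two-step alternative to `minimize()` (illustrative only;
        assumes `mp_optimizer` is the decorated optimizer and `loss` was built
        in the default main program):

        .. code-block:: python

            scaled_params_grads = mp_optimizer.backward(loss)
            optimize_ops = mp_optimizer.apply_optimize(
                loss, paddle.static.default_startup_program(), scaled_params_grads
            )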
        """
        train_program = loss.block.program
        self._train_program = train_program
        self._float_status = None

        with program_guard(self._train_program, startup_program):
            self._init_amp_var()

            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            else:
                # use_fp16_guard is not supported for amp O1.
                cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level=self._amp_level,
                    use_promote=self.use_promote,
                )

            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
            # When dynamic loss scaling is disabled and the init loss scaling
            # value equals 1.0, scaling the loss is a no-op and can be skipped.
            if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
                self._scaled_loss = loss * self._loss_scaling
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
            if self._supports_check_nan_inf():
                self._add_cast_ops_to_startup_program(startup_program)
        return params_grads

    def _add_cast_ops_to_startup_program(self, startup_program):
        names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
        names.sort()
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        param_names = [p.name for p in block.all_parameters()]
        for name in names:
            if name not in param_names:
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_vartype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Init the amp training, such as casting fp32 parameters to fp16.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can help avoid poor
                    # accuracy or slow convergence in some cases.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."
        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_vartype,
            )
        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            elif use_fp16_test:
                # use_fp16_guard is not supported for amp O1.
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level=self._amp_level,
                    use_promote=self.use_promote,
                )

    def _append_cast_to_master_grad_op(self, param_grads):
        """
        Create master gradient vars and append cast-gradient-to-master-gradient ops to the main program.

        Args:
          param_grads(list(tuple(Tensor, Tensor))): A list of (parameter, gradient) pairs to update.

        Returns:
          list: A list of (parameter, master_gradient) pairs. In the following grad-clip and optimizer steps, params are updated by the master gradients; the cast ops are appended to the main program before the grad-clip ops.
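
        A sketch of how this path is enabled from the public API (illustrative
        only; master gradients only take effect together with level 'O2'):

        .. code-block:: python

            optimizer = paddle.static.amp.decorate(
                optimizer, level='O2', dtype='float16', master_grad=True)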

        """

        if not self._use_master_grad:
            return param_grads

        global_block = self._train_program.global_block()
        target_block = global_block
        current_block = self._train_program.current_block()
        if current_block.idx != global_block.idx:
            target_block = self._train_program.blocks[
                current_block.backward_block_idx
            ]
        params_master_grads = []

        assert isinstance(target_block, paddle.fluid.framework.Block)
        # create
        for p, g in param_grads:
            if g.name not in self._optimizer._master_grads.keys():
                if self._optimizer._is_dtype_fp16_or_bf16(g.dtype):
                    master_g = self._optimizer._create_master_grad(g)
                    params_master_grads.append((p, master_g))
                    target_block.append_op(
                        type="cast",
                        inputs={"X": [g]},
                        outputs={"Out": [master_g]},
                        attrs={
                            "in_dtype": g.dtype,
                            "out_dtype": master_g.dtype,
                        },
                    )
                else:
                    params_master_grads.append((p, g))

        return params_master_grads

    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """

        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, params_grads)

        # Create master grad and add cast op into program
        params_grads = self._append_cast_to_master_grad_op(params_grads)

        # When dynamic loss scaling is disabled and the init loss scaling
        # value equals 1.0, the gradients are not scaled, so they can be
        # applied directly.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
            self._optimizer._set_scale(self._loss_scaling)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
            found_inf = self._optimizer._found_inf
            self._add_dynamic_loss_scaling(params_grads, found_inf)
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_vartype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam, to skip the update not only for param, but also for momentum and beta_pow.
        # With fleet, optimizers are nested and the real optimizer set by the user is the innermost one.
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
            # copy it in advance to avoid copying it multiple times.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops

    def _split_grads(self, params_grads):
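        """Split the gradients into fp32 grads and fp16/bf16 grads by dtype."""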
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads

    def _check_finite_and_unscale(self, params_grads):
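        """Unscale the gradients by the loss scaling factor and return a bool
        var indicating whether any inf/nan gradient was found."""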
        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        found_infs = []

        if self._is_distributed:
            # if distributed, split check_finite_and_unscale to overlap
            # unscale with communication
            for p, g in params_grads:
                with self._train_program._optimized_guard([p, g]):
                    _, found_inf = check_finite_and_unscale(
                        [
                            g,
                        ],
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                    found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
                with self._train_program._optimized_guard(fp32_grads):
                    _, fp32_found_inf = check_finite_and_unscale(
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
                    _, fp16_found_inf = check_finite_and_unscale(
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
                _, found_inf = check_finite_and_unscale(
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            with self._train_program._optimized_guard([]):
                all_infs = paddle.concat(found_infs)
                found_inf = paddle.any(all_infs)

        return found_inf

    def _add_dynamic_loss_scaling(self, params_grads, found_inf):
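        """Append update_loss_scaling ops that adjust the loss scaling factor
        (and the good/bad step counters) according to found_inf."""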
        if self._supports_check_nan_inf():
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    [],
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        if self._use_pure_fp16:
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )

    def apply_optimize(self, loss, startup_program, params_grads):
        program = loss.block.program
        with program_guard(program, startup_program):
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables that should be ignored.

        Returns:
            The list of optimize ops and a list of scaled parameters and
            gradients.
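
        A minimal sketch (illustrative only; mirrors Examples 1 of `decorate`
        below):

        .. code-block:: python

            optimizer = paddle.optimizer.Adam(learning_rate=0.001)
            mp_optimizer = paddle.static.amp.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)
            optimize_ops, scaled_params_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()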
        """

        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads


@overload(key=FunctionType.FP16_ONLY)
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value equals `use_pure_fp16`.
        use_bf16(bool): Whether to enable bfloat16 training. Default False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
            .. code-block:: python

            # black&white list based strategy example
            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                    optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()

    Examples 2:
        .. code-block:: python

            # pure fp16 training example
            import numpy as np
            import paddle
            import paddle.nn.functional as F

            def run_example_code():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                # 1) Use fp16_guard to control the range of fp16 kernels used.
                with paddle.static.amp.fp16_guard():
                    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                    hidden = paddle.static.nn.fc(pool, size=10)
                    loss = paddle.mean(hidden)
                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True can help avoid poor
                # accuracy or slow convergence in some cases.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
                    custom_black_list=['pool2d'])
                # 4) The entry of Paddle AMP.
                # Enable pure fp16 training by setting `use_pure_fp16` to True.
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=128.0,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=True)
                # If you don't use the default_startup_program(), you should pass
                # your defined `startup_program` into `minimize`.
                optimizer.minimize(loss)
                exe.run(paddle.static.default_startup_program())
                # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    amp_dtype = "bfloat16" if use_bf16 else "float16"
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if use_fp16_guard is None:
        use_fp16_guard = use_pure_fp16

    amp_level = "O2" if use_pure_fp16 else "O1"
    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=amp_level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_fp16_guard,
        use_promote=use_promote,
    )

    return mp_optimizer


@overload(key=FunctionType.COMMON)
def decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    master_weight=None,
    master_grad=False,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=None,
    use_amp_guard=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
            white_list and black_list will be used for AMP training when it is
            not set. Default is None.
        level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
             will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
             level, the model's parameters will be cast to float16/bfloat16 by using `decorator`, operators that have all float16/bfloat16 inputs
             will be converted to float16/bfloat16, and operators that have any float32 input will be converted to float32. For the OD level, operators in the
             default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
        master_weight(bool, optional): For level='O2', whether to use multi-precision
            during weight updating. If master_weight is None, in O2 level optimizer
            will use multi-precision. Default is None.
        master_grad(bool, optional): For level='O2', whether to use master_grad
            during weight updating. If master_grad is False, in O2 level optimizer
            will not use master grad. Default is False.
        init_loss_scaling(float, optional): The initial loss scaling factor.
            Default is 32768.
        incr_every_n_steps(int, optional): Increases loss scaling every n
            consecutive steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
            accumulated steps with nan or inf gradients. Default is 2.
        incr_ratio(float, optional): The multiplier to use when increasing the
            loss scaling. Default is 2.
        decr_ratio(float, optional): The less-than-one multiplier to use when
            decreasing the loss scaling. Default is 0.8.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
            scaling. Default is None, which means True for float16, and False
            for bfloat16.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training enabled.

    Examples:

     .. code-block:: python

        import paddle

        paddle.enable_static()

        class SimpleConvNet(paddle.nn.Layer):
            def __init__(self):
                super().__init__()
                self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
                self.linear = paddle.nn.Linear(in_features=26, out_features=10)

            def forward(self, x):
                out = self.conv(x)
                out = paddle.nn.functional.relu(out)
                out = self.linear(out)
                out = paddle.nn.functional.softmax(out)
                return out

        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.utils.unique_name.guard():
            with paddle.static.program_guard(main_program, startup_program):
                model = SimpleConvNet()
                x = paddle.static.data(
                    name='input', shape=[None, 1, 28, 28], dtype='float32'
                )
                out = model(x)
                loss = paddle.mean(out)
                optimizer = paddle.optimizer.AdamW()
                optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
                optimizer.minimize(loss)

        if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
            place = paddle.CUDAPlace(0)
            exe = paddle.static.Executor(place)
            exe.run(startup_program)

            # Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
            # to convert FP32 parameters to low precision FP16 / BF16.
            optimizer.amp_init(place, scope=paddle.static.global_scope())

    """
    # check amp_level: O0, OD, O1, O2
    level = level.upper()
    if level not in ['O0', 'OD', 'O1', 'O2']:
        raise ValueError("level should be O0, OD, O1 or O2.")

    amp_dtype = check_amp_dtype(dtype)
    if level == 'OD' and amp_lists is not None:
        warnings.warn(
            "If the Amp level is set to OD, the amp list will not be used."
        )

    if amp_lists is None or level == 'OD':
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if level == 'OD':
        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

    if use_dynamic_loss_scaling is None:
        use_dynamic_loss_scaling = dtype == "float16"

    if optimizer is not None:
        # support master_weight
        multi_precision = not (master_weight is False)
        _set_multi_precision(optimizer, multi_precision)

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
        use_promote=use_promote,
        use_master_grad=master_grad,
    )

    return mp_optimizer