# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import warnings

import paddle
from paddle.fluid import (
    core,
    default_main_program,
    default_startup_program,
    program_guard,
    unique_name,
)

from .amp_nn import check_finite_and_unscale, update_loss_scaling
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
    cast_model_to_fp16,
    cast_parameters_to_fp16,
    update_role_var_grad,
)
from .function_overload import FunctionType, overload


class OptimizerWithMixedPrecision:
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer that adds support for mixed-precision training. An object of this
    class behaves almost the same as the wrapped optimizer, with the methods
    `minimize()`, `backward()` and `apply_gradients()` implemented.
    Additionally, it enables MP training automatically, i.e., the creation
    and maintenance of master parameters, loss scaling, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are "O1" and
            "O2": O1 represents mixed precision, where the input data type of
            each operator is cast according to the white_list and black_list;
            O2 represents pure fp16 or bf16, where all operator parameters and
            input data are cast to fp16 or bf16, except for operators in the
            black_list, operators without an fp16/bf16 kernel, and batch_norm.
        dtype(str): The low-precision data type, either 'float16' or 'bfloat16'.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increase the loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decrease the loss scaling every n
                                      accumulated steps with nan or inf
                                      gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
        use_amp_guard(bool): Whether to use `fp16_guard` when constructing the
                           program. Default None, which means that its value is
                           equal to `use_pure_fp16`.
        use_promote(bool): Whether to promote to fp32 when an op has any
                           float32 inputs. Default is False.
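
    Examples:
        A minimal sketch of how this wrapper is usually obtained and driven;
        the network below is illustrative, any static-graph model works the
        same way:

        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            # `decorate` returns an OptimizerWithMixedPrecision instance.
            mp_optimizer = static.amp.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)
            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()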
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        level,
        dtype,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
        use_promote=False,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = None

        self._is_distributed = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if dtype == "bfloat16":
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_vartype = core.VarDesc.VarType.BF16
        else:
            self._amp_vartype = core.VarDesc.VarType.FP16

        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        if self._use_dynamic_loss_scaling:
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._num_good_steps = None
            self._num_bad_steps = None
        self.use_promote = use_promote

    def _set_distributed(self, flag):
        # If distributed, all cards will communicate with each other, so
        # overlap communication and computation by splitting the
        # check_finite_and_unscale op.
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.
        It is useful when you feed a customized loss into the executor.
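
        Examples:
            A minimal sketch (``exe``, ``loss`` and the feed below are
            illustrative names): after ``minimize()``, fetch the scaled loss
            when running the program yourself.

            .. code-block:: python

                mp_optimizer.minimize(loss)
                scaled_loss = mp_optimizer.get_scaled_loss()
                exe.run(paddle.static.default_startup_program())
                out = exe.run(
                    paddle.static.default_main_program(),
                    feed={'X': x_data},
                    fetch_list=[scaled_loss],
                )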
        """
        return self._scaled_loss

    def _supports_check_nan_inf(self):
        return getattr(self._optimizer, "_supports_check_nan_inf", False)

    def _init_amp_var(self):
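        """Create the persistable global variables used by AMP: the loss
        scaling factor and, when dynamic loss scaling is enabled, the
        good/bad step counters; also make sure the learning rate variable
        is created as float32."""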
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype).
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = paddle.static.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
            callbacks (list|None): A list of callable objects to run when appending
                                   the backward operator for one parameter.

        Returns:
            A list of (param, grad) pairs, where each grad is the gradient of
            the scaled loss with respect to its parameter.
        """
        train_program = loss.block.program
        self._train_program = train_program

        # NOTE(zhiqiu): _float_status is only used for NPU.
        if core.is_compiled_with_custom_device('npu'):
            float_status = paddle.static.data(
                name="float_status", shape=[8], dtype='float32'
            )
            self._train_program.global_block().append_op(
                type="alloc_float_status",
                outputs={"FloatStatus": float_status},
            )
            self._train_program.global_block().append_op(
                type="clear_float_status",
                inputs={"FloatStatus": float_status},
                outputs={"FloatStatusOut": float_status},
            )
            self._float_status = float_status
        else:
            self._float_status = None

        with program_guard(self._train_program, startup_program):
            self._init_amp_var()

            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            else:
                # use_fp16_guard is not supported for AMP-O1.
                cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level='O1',
                    use_promote=self.use_promote,
                )

            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
            # When dynamic loss scaling is not used and the initial loss
            # scaling equals 1.0, the scaling multiplication can be skipped.
            if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
                self._scaled_loss = loss * self._loss_scaling
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
            if self._supports_check_nan_inf():
                self._add_cast_ops_to_startup_program(startup_program)
        return params_grads

    def _add_cast_ops_to_startup_program(self, startup_program):
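        """Insert assign + cast ops into the startup program so that the
        parameters recorded in ``self._to_fp16_var_names`` are initialized
        in fp32 and then cast to the AMP dtype."""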
        names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
        names.sort()
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        param_names = [p.name for p in block.all_parameters()]
        for name in names:
            if name not in param_names:
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_vartype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Initialize AMP training, e.g., cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): The place used to initialize fp16 parameters
                with their fp32 values.
            scope(Scope): The scope used to find the fp32 parameters.
            test_program(Program): The program used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True helps avoid poor accuracy
                    # or slow convergence.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) The ops in `custom_black_list` will be kept in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameter initialization (such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."
        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_vartype,
            )
        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            elif use_fp16_test:
                # use_fp16_guard is not supported for AMP-O1.
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level='O1',
                    use_promote=self.use_promote,
                )

    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """

        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, params_grads)

        # When dynamic loss scaling is not used and the initial loss scaling
        # equals 1.0, no unscaling is needed and the gradients can be applied
        # directly.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
            self._optimizer._set_scale(self._loss_scaling)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
            found_inf = self._optimizer._found_inf
            self._add_dynamic_loss_scaling(params_grads, found_inf)
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_vartype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam to skip the update not only for the param,
        # but also for momentum and beta_pow.
        # With fleet, optimizers are nested and the real optimizer set by the
        # user is the innermost one.
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
            # copy it in advance to avoid multiple copies.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops

    def _split_grads(self, params_grads):
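        """Split the gradients by dtype into (all, fp32, fp16/bf16) lists."""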
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads

    def _check_finite_and_unscale(self, params_grads):
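        """Append check_finite_and_unscale ops for the scaled gradients and
        return a boolean variable indicating whether any inf/nan was found.
        In distributed or pure-fp16 mode the check is split into several ops
        so that unscaling can overlap with communication."""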
        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        found_infs = []

        if self._is_distributed:
            # If distributed, split check_finite_and_unscale to overlap
            # unscaling with communication.
            if core.is_compiled_with_custom_device('npu'):
                with self._train_program._optimized_guard(grads):
                    _, found_inf = check_finite_and_unscale(
                        grads,
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                    found_infs.append(found_inf)
            else:
                for p, g in params_grads:
                    with self._train_program._optimized_guard([p, g]):
                        _, found_inf = check_finite_and_unscale(
                            [
                                g,
                            ],
                            self._loss_scaling,
                            name="find_infinite_scale",
                            float_status=self._float_status,
                        )
                        found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
                with self._train_program._optimized_guard(fp32_grads):
                    _, fp32_found_inf = check_finite_and_unscale(
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
                    _, fp16_found_inf = check_finite_and_unscale(
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
                _, found_inf = check_finite_and_unscale(
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            with self._train_program._optimized_guard([]):
                all_infs = paddle.concat(found_infs)
                found_inf = paddle.any(all_infs)

        return found_inf

    def _add_dynamic_loss_scaling(self, params_grads, found_inf):
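        """Append update_loss_scaling ops that adjust the loss scaling factor
        and the good/bad step counters according to ``found_inf``."""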
        if self._supports_check_nan_inf():
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    [],
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        if self._use_pure_fp16:
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )

    def apply_optimize(self, loss, startup_program, params_grads):
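        """Apply the given parameter gradients within the program that owns
        ``loss`` and return the resulting optimize ops."""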
        program = loss.block.program
        with program_guard(program, startup_program):
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.

        Returns:
            The list of optimize ops and the list of scaled parameters and
            gradients.
        """

        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads


@overload(key=FunctionType.FP16_ONLY)
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increase the loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decrease the loss scaling every n
                                      accumulated steps with nan or inf
                                      gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value equals `use_pure_fp16`.
        use_bf16(bool): Whether to enable bfloat16 training. Default False.
        use_promote(bool): Whether to promote to fp32 when an op has any
                           float32 inputs. Default False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
        .. code-block:: python

            # black&white list based strategy example
            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                    optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()

    Examples 2:
        .. code-block:: python

            # pure fp16 training example
            import numpy as np
            import paddle
            import paddle.nn.functional as F

            def run_example_code():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                # 1) Use fp16_guard to control the range of fp16 kernels used.
                with paddle.static.amp.fp16_guard():
                    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                    hidden = paddle.static.nn.fc(pool, size=10)
                    loss = paddle.mean(hidden)
                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True helps avoid poor accuracy
                # or slow convergence.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                # 3) The ops in `custom_black_list` will be kept in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
                    custom_black_list=['pool2d'])
                # 4) The entry of Paddle AMP.
                # Enable pure fp16 training by setting `use_pure_fp16` to True.
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=128.0,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=True)
                # If you don't use the default_startup_program(), you should pass
                # your defined `startup_program` into `minimize`.
                optimizer.minimize(loss)
                exe.run(paddle.static.default_startup_program())
                # 5) Use `amp_init` after FP32 parameter initialization (such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    amp_dtype = "bfloat16" if use_bf16 else "float16"
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if use_fp16_guard is None:
        use_fp16_guard = use_pure_fp16

    amp_level = "O2" if use_pure_fp16 else "O1"
    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=amp_level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_fp16_guard,
        use_promote=use_promote,
    )

    return mp_optimizer


@overload(key=FunctionType.COMMON)
def decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_amp_guard=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to mixed-precision training.

    Unlike the overload above, this interface selects the AMP behavior through
    `level` ('O0', 'O1' or 'O2') and `dtype` ('float16' or 'bfloat16') instead
    of the `use_pure_fp16`/`use_bf16` flags.
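
    Examples:
        A minimal sketch of the level/dtype based interface, assuming this
        overload is reachable as ``paddle.static.amp.decorate`` like the one
        above; the network and hyper-parameters are illustrative only.

        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            # O1 inserts casts according to the white/black lists; pass
            # level='O2' (optionally with dtype='bfloat16') for pure
            # low-precision training.
            mp_optimizer = static.amp.decorate(
                optimizer, level='O1', dtype='float16', init_loss_scaling=128.0)
            ops, param_grads = mp_optimizer.minimize(loss)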
    """
    amp_dtype = check_amp_dtype(dtype)
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    # check amp_level: O0-O2
    level = level.upper()
    if not (level in ['O0', 'O1', 'O2']):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
        use_promote=use_promote,
    )

    return mp_optimizer