# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import warnings

import paddle
from paddle.fluid import (
    core,
    default_main_program,
    default_startup_program,
    program_guard,
    unique_name,
)

from .amp_nn import check_finite_and_unscale, update_loss_scaling
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
    cast_model_to_fp16,
    cast_parameters_to_fp16,
    rewrite_program,
    update_role_var_grad,
)
from .function_overload import FunctionType, overload


class OptimizerWithMixedPrecision:
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. The object
    of this class has almost the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables the MP training automatically, i.e., the creation
    and maintenance of master parameters, scaling of loss, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are
            "O1" and "O2": O1 represents mixed precision, where the input data
            type of each operator is cast according to the white_list and
            black_list; O2 represents pure fp16 or bf16, where all operator
            parameters and input data are cast to fp16 or bf16, except for
            operators in the black_list, operators without an fp16/bf16 kernel,
            and batch_norm.
        dtype(str): The data type to use, either 'float16' or 'bfloat16'.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value is equal to `use_pure_fp16`.
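
    Examples:
        A minimal usage sketch (mirroring Examples 1 of `decorate` below); in
        practice this wrapper is obtained via `paddle.static.amp.decorate`
        rather than constructed directly:

        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)
            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()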
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        level,
        dtype,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = None

        self._is_distributed = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if dtype == "bfloat16":
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_vartype = core.VarDesc.VarType.BF16
        else:
            self._amp_vartype = core.VarDesc.VarType.FP16

        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        if self._use_dynamic_loss_scaling:
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._num_good_steps = None
            self._num_bad_steps = None

    def _set_distributed(self, flag):
        # if distributed, all cards communicate with each other;
        # overlap communication and computation by splitting the
        # check_finite_and_unscale op.
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.
        It's useful when you feed a custom loss into the executor.
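
        A minimal sketch, assuming `mp_optimizer`, an executor `exe` and an
        input batch `x_np` were set up as in the examples of `decorate`:

        .. code-block:: python

            scaled_loss = mp_optimizer.get_scaled_loss()
            # Fetch the scaled loss that is actually back-propagated,
            # instead of the raw loss.
            out = exe.run(paddle.static.default_main_program(),
                          feed={'X': x_np},
                          fetch_list=[scaled_loss])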
        """
        return self._scaled_loss

    def _supports_check_nan_inf(self):
        return getattr(self._optimizer, "_supports_check_nan_inf", False)

    def _init_amp_var(self):
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = paddle.static.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
            callbacks (list|None): A list of callable objects to run when appending
                                   backward operator for one parameter.

        Returns:
            A list of (param, grad) pairs, where each grad is the scaled
            gradient of its param; the scaled loss is available via
            `get_scaled_loss()`.
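
        Example (a minimal sketch, assuming `mp_optimizer` was created by
        `paddle.static.amp.decorate`, `loss` is a float32 Variable and
        `startup_program` is the corresponding startup Program):

        .. code-block:: python

            # backward() appends the backward ops for the scaled loss and
            # returns the (param, scaled_grad) pairs.
            params_grads = mp_optimizer.backward(loss, startup_program)
            # apply_optimize() (or minimize()) unscales the gradients, checks
            # them for inf/nan and applies the parameter update.
            ops = mp_optimizer.apply_optimize(loss, startup_program, params_grads)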
        """
        train_program = loss.block.program
        self._train_program = train_program

        # NOTE(zhiqiu): _float_status is only used for NPU.
        if core.is_compiled_with_custom_device('npu'):
            float_status = paddle.static.data(
                name="float_status", shape=[8], dtype='float32'
            )
            self._train_program.global_block().append_op(
                type="alloc_float_status",
                outputs={"FloatStatus": float_status},
            )
            self._train_program.global_block().append_op(
                type="clear_float_status",
                inputs={"FloatStatus": float_status},
                outputs={"FloatStatusOut": float_status},
            )
            self._float_status = float_status
        else:
            self._float_status = None

        with program_guard(self._train_program, startup_program):
            self._init_amp_var()

            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                )
            else:
                rewrite_program(
                    self._train_program, self._amp_lists, self._amp_vartype
                )

            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
            # When dynamic loss scaling is disabled and the initial loss scaling
            # equals 1.0, scaling the loss is a no-op and can be skipped.
            if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
                self._scaled_loss = loss * self._loss_scaling
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
            if self._supports_check_nan_inf():
                self._add_cast_ops_to_startup_program(startup_program)
        return params_grads

    def _add_cast_ops_to_startup_program(self, startup_program):
        names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
        names.sort()
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        param_names = [p.name for p in block.all_parameters()]
        for name in names:
            if name not in param_names:
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_vartype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can help avoid poor
                    # accuracy or slow convergence.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."
        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_vartype,
            )
        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                )
            elif use_fp16_test:
                rewrite_program(
                    test_program, self._amp_lists, self._amp_vartype
                )

    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
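
        A minimal sketch, assuming `params_grads` was produced by `backward()`;
        the ops must be appended under the same program, which `apply_optimize()`
        handles through `program_guard`:

        .. code-block:: python

            with paddle.static.program_guard(loss.block.program, startup_program):
                optimize_ops = mp_optimizer.apply_gradients(params_grads)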
        """

        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, params_grads)

        # When dynamic loss scaling is disabled and the initial loss scaling
        # equals 1.0, gradients are not scaled and can be applied directly.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
            self._optimizer._set_scale(self._loss_scaling)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
            found_inf = self._optimizer._found_inf
            self._add_dynamic_loss_scaling(params_grads, found_inf)
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_vartype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
        # With fleet, optimizers are nested and the real optimizer set by user is the inner most one.
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on CPU in the adam op, we
            # copy it in advance to avoid copying it multiple times.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops

    def _split_grads(self, params_grads):
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads

    def _check_finite_and_unscale(self, params_grads):
        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        found_infs = []

        if self._is_distributed:
            # if distributed, split check_finite_and_unscale to overlap
            # unscale with communication
            if core.is_compiled_with_custom_device('npu'):
                with self._train_program._optimized_guard(grads):
                    _, found_inf = check_finite_and_unscale(
                        grads,
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                    found_infs.append(found_inf)
            else:
                for p, g in params_grads:
                    with self._train_program._optimized_guard([p, g]):
                        _, found_inf = check_finite_and_unscale(
                            [
                                g,
                            ],
                            self._loss_scaling,
                            name="find_infinite_scale",
                            float_status=self._float_status,
                        )
                        found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
                with self._train_program._optimized_guard(fp32_grads):
                    _, fp32_found_inf = check_finite_and_unscale(
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
                    _, fp16_found_inf = check_finite_and_unscale(
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
                _, found_inf = check_finite_and_unscale(
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            with self._train_program._optimized_guard([]):
                all_infs = paddle.concat(found_infs)
                found_inf = paddle.any(all_infs)

        return found_inf

    def _add_dynamic_loss_scaling(self, params_grads, found_inf):
        if self._supports_check_nan_inf():
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    [],
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        if self._use_pure_fp16:
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )

    def apply_optimize(self, loss, startup_program, params_grads):
        program = loss.block.program
        with program_guard(program, startup_program):
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.

        Returns:
            The list of optimize ops and a list of scaled (param, grad) pairs.
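
        Example (a minimal sketch, following Examples 1 of `decorate` below;
        `optimizer` and `loss` are assumed to be built as shown there):

        .. code-block:: python

            mp_optimizer = paddle.static.amp.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)
            optimize_ops, params_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()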
        """

        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads


@overload(key=FunctionType.FP16_ONLY)
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value is equal to `use_pure_fp16`.
        use_bf16(bool): Whether to enable bfloat16 training. Default False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
        .. code-block:: python

            # black&white list based strategy example
            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                    optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()

    Examples 2:
        .. code-block:: python

            # pure fp16 training example
            import numpy as np
            import paddle
            import paddle.nn.functional as F

            def run_example_code():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                # 1) Use fp16_guard to control the range of fp16 kernels used.
                with paddle.static.amp.fp16_guard():
                    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                    hidden = paddle.static.nn.fc(pool, size=10)
                    loss = paddle.mean(hidden)
                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True can help avoid poor
                # accuracy or slow convergence.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
                    custom_black_list=['pool2d'])
                # 4) The entry of Paddle AMP.
                # Enable pure fp16 training by setting `use_pure_fp16` to True.
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=128.0,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=True)
                # If you don't use the default_startup_program(), you should pass
                # your defined `startup_program` into `minimize`.
                optimizer.minimize(loss)
                exe.run(paddle.static.default_startup_program())
                # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    amp_dtype = "bfloat16" if use_bf16 else "float16"
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if use_fp16_guard is None:
        use_fp16_guard = use_pure_fp16

    amp_level = "O2" if use_pure_fp16 else "O1"
    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=amp_level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_fp16_guard,
    )

    return mp_optimizer


@overload(key=FunctionType.COMMON)
def decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_amp_guard=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.
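
    A minimal sketch (assuming this overload is selected when the `level` and
    `dtype` keyword arguments are passed; `optimizer` and `loss` are built as
    in Examples 1 above):

    .. code-block:: python

        mp_optimizer = paddle.static.amp.decorate(
            optimizer=optimizer, level='O1', dtype='float16',
            init_loss_scaling=8.0)
        ops, param_grads = mp_optimizer.minimize(loss)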
    """
    amp_dtype = check_amp_dtype(dtype)
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    # check amp_level: O0-O2
    level = level.upper()
    if not (level in ['O0', 'O1', 'O2']):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
    )

    return mp_optimizer