decorator.py 31.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import types
import warnings
17

18
import paddle
19 20 21 22 23 24 25 26 27
from paddle.fluid import (
    core,
    default_main_program,
    default_startup_program,
    program_guard,
    unique_name,
)

from .amp_nn import check_finite_and_unscale, update_loss_scaling
28
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
29 30 31 32 33 34
from .fp16_utils import (
    cast_model_to_fp16,
    cast_parameters_to_fp16,
    rewrite_program,
    update_role_var_grad,
)
35 36


37
class OptimizerWithMixedPrecision:
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. The object
    of this class almost has the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables the MP training automatically, i.e., the creation
    and maintenance of master parameters, scaling of loss, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are
            "O0", "O1" and "O2": O0 keeps the program in fp32 (operators are
            not rewritten; loss scaling settings still apply); O1 represents
            mixed precision, the input data type of each operator will be cast
            by white_list and black_list; O2 represents pure fp16 or bf16, all
            operators' parameters and input data will be cast to fp16 or bf16,
            except operators in black_list, operators that don't support fp16
            or bf16 kernels, and batch_norm.
        dtype(str): Whether to use 'float16' or 'bfloat16'. With 'bfloat16',
            a request for dynamic loss scaling is turned off and
            init_loss_scaling is reset to 1.0 (a warning is emitted).
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_amp_guard(bool|None): Whether to use `fp16_guard` when constructing
                           the program. Default None, which means that its value
                           is equal to whether pure fp16/bf16 training
                           (level "O2") is used.
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        level,
        dtype,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = None

        self._is_distributed = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if dtype == "bfloat16":
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_vartype = core.VarDesc.VarType.BF16
        else:
            self._amp_vartype = core.VarDesc.VarType.FP16

        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._amp_level = level
        self._use_pure_fp16 = level == "O2"
        # Honor the documented default: None means "same as pure fp16 mode".
        self._use_fp16_guard = (
            self._use_pure_fp16 if use_amp_guard is None else use_amp_guard
        )
        self._to_fp16_var_names = None
        self._float_status = None

        # Always define the dynamic-scaling attributes so that code paths
        # which inspect them never hit an AttributeError; they are only
        # used when dynamic loss scaling is enabled.
        self._incr_every_n_steps = incr_every_n_steps
        self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self._incr_ratio = incr_ratio
        self._decr_ratio = decr_ratio
        self._num_good_steps = None
        self._num_bad_steps = None

    def _set_distributed(self, flag):
        # if distributed, all cards will communication with each other,
        # overlap communication and computation by split the
        # check_finite_and_unscale op.
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.
        It's useful when you feed a custom loss into the executor.
        """
        return self._scaled_loss

    def _supports_check_nan_inf(self):
        """Return whether the wrapped optimizer checks nan/inf by itself
        (signalled by a truthy `_supports_check_nan_inf` attribute)."""
        return getattr(self._optimizer, "_supports_check_nan_inf", False)

    def _init_amp_var(self):
        """Create the persistable loss-scaling variables (and, for a float
        learning rate, a float32 learning-rate variable) in the current
        program."""
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = paddle.static.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables should be ignored.
            callbacks (list|None): A list of callable objects to run when appending
                                   backward operator for one parameter.

        Returns:
            A list of (param, grad), which is a tuple of a parameter and its
            gradient respectively. The scaled loss is not returned; use
            `get_scaled_loss()` to obtain it.
        """
        train_program = loss.block.program
        self._train_program = train_program

        # NOTE(zhiqiu): _float_status is only used for NPU.
        # NOTE(review): `paddle.static.data` here runs outside the
        # program_guard below, so it targets the current default program —
        # confirm that is always `train_program`.
        if core.is_compiled_with_custom_device('npu'):
            float_status = paddle.static.data(
                name="float_status", shape=[8], dtype='float32'
            )
            self._train_program.global_block().append_op(
                type="alloc_float_status",
                outputs={"FloatStatus": float_status},
            )
            self._train_program.global_block().append_op(
                type="clear_float_status",
                inputs={"FloatStatus": float_status},
                outputs={"FloatStatusOut": float_status},
            )
            self._float_status = float_status
        else:
            self._float_status = None

        with program_guard(self._train_program, startup_program):
            self._init_amp_var()

            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                )
            elif self._amp_level != "O0":
                # O0 is fp32 training: the program must not be rewritten.
                rewrite_program(
                    self._train_program, self._amp_lists, self._amp_vartype
                )

            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
            # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
            # the model can be optimized.
            if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
                self._scaled_loss = loss * self._loss_scaling
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
            if self._supports_check_nan_inf():
                self._add_cast_ops_to_startup_program(startup_program)
        return params_grads

    def _add_cast_ops_to_startup_program(self, startup_program):
        """Append assign+cast ops to the startup program so that every
        parameter recorded in `_to_fp16_var_names` is cast from fp32 to the
        AMP dtype at initialization; then clear `_to_fp16_var_names`."""
        names = (
            sorted(self._to_fp16_var_names) if self._to_fp16_var_names else []
        )
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        # A set gives O(1) membership checks in the loop below.
        param_names = {p.name for p in block.all_parameters()}
        for name in names:
            if name not in param_names:
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_vartype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."
        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_vartype,
            )
        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                )
            elif use_fp16_test:
                rewrite_program(
                    test_program, self._amp_lists, self._amp_vartype
                )

    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """

        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, params_grads)

        # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
        # the model can be optimized.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
            self._optimizer._set_scale(self._loss_scaling)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
            # The good/bad step counters only exist with dynamic loss
            # scaling; with a static (non-1.0) scale there is nothing to
            # update, and calling the updater would fail.
            if self._use_dynamic_loss_scaling:
                found_inf = self._optimizer._found_inf
                self._add_dynamic_loss_scaling(params_grads, found_inf)
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_vartype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
        # With fleet, optimizers are nested and the real optimizer set by user is the inner most one.
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
            # copy it in advance to avoid multiple time copies.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops

    def _split_grads(self, params_grads):
        """Return (all grads, fp32 grads, AMP-dtype grads); every grad must
        be either fp32 or the AMP dtype."""
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads

    def _check_finite_and_unscale(self, params_grads):
        """Append check_finite_and_unscale ops for the gradients and return
        the resulting found_inf variable (combined with `paddle.any` when
        several checks were emitted)."""
        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        found_infs = []

        if self._is_distributed:
            # if distributed, split check_finite_and_unscale to overlap
            # unscale with communication
            if core.is_compiled_with_custom_device('npu'):
                with self._train_program._optimized_guard(grads):
                    _, found_inf = check_finite_and_unscale(
                        grads,
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                    found_infs.append(found_inf)
            else:
                for p, g in params_grads:
                    with self._train_program._optimized_guard([p, g]):
                        _, found_inf = check_finite_and_unscale(
                            [
                                g,
                            ],
                            self._loss_scaling,
                            name="find_infinite_scale",
                            float_status=self._float_status,
                        )
                        found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
                with self._train_program._optimized_guard(fp32_grads):
                    _, fp32_found_inf = check_finite_and_unscale(
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
                    _, fp16_found_inf = check_finite_and_unscale(
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
                _, found_inf = check_finite_and_unscale(
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            with self._train_program._optimized_guard([]):
                all_infs = paddle.concat(found_infs)
                found_inf = paddle.any(all_infs)

        return found_inf

    def _add_dynamic_loss_scaling(self, params_grads, found_inf):
        """Append update_loss_scaling op(s) driven by `found_inf`. Requires
        dynamic loss scaling (the step counters created in _init_amp_var)."""
        if self._supports_check_nan_inf():
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    [],
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        if self._use_pure_fp16:
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    # Only the first update_loss_scaling op may change the
                    # scale; the second one must not update it again.
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )

    def apply_optimize(self, loss, startup_program, params_grads):
        """Run `apply_gradients` under a program_guard on the loss's program."""
        program = loss.block.program
        with program_guard(program, startup_program):
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables should be ignored.

        Returns:
            A tuple ``(optimize_ops, scaled_params_grads)``: the list of
            optimize ops and the list of (param, scaled grad) pairs. Use
            `get_scaled_loss()` to obtain the scaled loss.
        """

        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads
611 612


613 614 615 616 617 618 619 620 621 622 623
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object. Default None, in
            which case a default AutoMixedPrecisionLists for the selected
            dtype is created.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training (AMP level
                           "O2"; otherwise "O1"). Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value is equal to `use_pure_fp16`.
        use_bf16(bool): Whether to enable bfloat16 training instead of
                           float16. Default False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
            .. code-block:: python

            # black&white list based strategy example
            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                    optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()

    Examples 2:
        .. code-block:: python

            # pure fp16 training example
            import numpy as np
            import paddle
            import paddle.nn.functional as F
            paddle.enable_static()

            def run_example_code():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                # 1) Use fp16_guard to control the range of fp16 kernels used.
                with paddle.static.amp.fp16_guard():
                    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                    hidden = paddle.static.nn.fc(pool, size=10)
                    loss = paddle.mean(hidden)
                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True can avoid the poor accuracy
                # or the slow convergence in a way.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
                    custom_black_list=['pool2d'])
                # 4) The entry of Paddle AMP.
                # Enable pure fp16 training by setting `use_pure_fp16` to True.
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=128.0,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=True)
                # If you don't use the default_startup_program(), you should pass
                # your defined `startup_program` into `minimize`.
                optimizer.minimize(loss)
                exe.run(paddle.static.default_startup_program())
                # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    amp_dtype = "bfloat16" if use_bf16 else "float16"
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    # By default the fp16 guard follows the pure-fp16 switch.
    if use_fp16_guard is None:
        use_fp16_guard = use_pure_fp16

    amp_level = "O2" if use_pure_fp16 else "O1"
    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=amp_level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_fp16_guard,
    )

    return mp_optimizer


def amp_decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_amp_guard=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (AutoMixedPrecisionLists|None): The op lists to use.
            Default None, in which case a default AutoMixedPrecisionLists for
            the validated dtype is created.
        level(str): AMP level, one of 'O0', 'O1' or 'O2' (case-insensitive;
            it is upper-cased before validation). 'O2' enables pure
            fp16/bf16 training.
        dtype(str): The low-precision dtype, validated by `check_amp_dtype`
            (e.g. 'float16' or 'bfloat16').
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
            steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
            accumulated steps with nan or inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
            scaling.
        decr_ratio(float): The less-than-one-multiplier to use when
            decreasing the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_amp_guard(bool): Whether to use `fp16_guard` when constructing
            the program. Default False.

    Returns:
        An OptimizerWithMixedPrecision wrapping `optimizer`.

    Raises:
        ValueError: If `level` is not one of 'O0', 'O1' or 'O2'.
    """
    amp_dtype = check_amp_dtype(dtype)
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    # check amp_level: O0-O2
    level = level.upper()
    if level not in ('O0', 'O1', 'O2'):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
    )

    return mp_optimizer