decorator.py 35.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import types
import warnings
17

18
import paddle
19 20 21 22 23 24 25 26 27
from paddle.fluid import (
    core,
    default_main_program,
    default_startup_program,
    program_guard,
    unique_name,
)

from .amp_nn import check_finite_and_unscale, update_loss_scaling
28
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
29 30 31 32 33
from .fp16_utils import (
    cast_model_to_fp16,
    cast_parameters_to_fp16,
    update_role_var_grad,
)
34
from .function_overload import FunctionType, overload
35 36


37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
def _set_multi_precision(optimizer, multi_precision):
    """Turn on master-weight (multi-precision) updates for an O2 optimizer.

    Args:
        optimizer: The optimizer being decorated for AMP O2 training. It must
            be an instance of ``paddle.optimizer.Optimizer`` or
            ``paddle.fluid.optimizer.Optimizer``.
        multi_precision (bool): When truthy, it is written to the optimizer's
            ``_multi_precision`` attribute (only if that attribute exists).
            A falsy value leaves the optimizer untouched.

    Raises:
        RuntimeError: If ``optimizer`` is not one of the accepted types.
    """
    accepted_types = (
        paddle.optimizer.Optimizer,
        paddle.fluid.optimizer.Optimizer,
    )
    if not isinstance(optimizer, accepted_types):
        raise RuntimeError(
            "Current AMP training level is O2, optimizer is expected to be "
            "paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, "
            f"but receive {type(optimizer)}."
        )

    # A falsy flag is a no-op: this helper never switches multi-precision off.
    if not multi_precision or not hasattr(optimizer, "_multi_precision"):
        return
    optimizer._multi_precision = multi_precision


52
class OptimizerWithMixedPrecision:
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. The object
    of this class almost has the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables the MP training automatically, i.e, the creation
    and maintenance of master parameters, scaling of loss, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are
            "O1" and "O2": O1 represent mixed precision, the input data type
            of each operator will be casted by white_list and black_list;
            O2 represent Pure fp16 or bf16, all operators parameters and input
            data will be casted to fp16 or bf16, except operators in black_list,
            don't support fp16 or bf16 kernel and batch_norm.
        dtype(str): Whether to use 'float16' or 'bfloat16'. With 'bfloat16',
            dynamic loss scaling is turned off and init_loss_scaling is forced
            to 1.0 (a warning is emitted if dynamic scaling was requested).
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_amp_guard(bool|None): Whether to use `fp16_guard` when constructing
                           the program. It only takes effect at level O2 (O1
                           always casts with the guard disabled). Default None,
                           which is passed through unchanged.
        use_promote(bool): Whether to promote to fp32 when op has any float32 inputs. Default is False.
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        level,
        dtype,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
        use_promote=False,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = None

        self._is_distributed = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if dtype == "bfloat16":
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_vartype = core.VarDesc.VarType.BF16
        else:
            self._amp_vartype = core.VarDesc.VarType.FP16

        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        # Reset by backward() and read by _check_finite_and_unscale(); defined
        # here so the attribute always exists.
        self._float_status = None
        # Dynamic loss scaling hyper-parameters and step counters. They are
        # always defined (even when dynamic loss scaling is off) so that no
        # code path can fail with an AttributeError on them. The counters are
        # only turned into program variables by _init_amp_var() when dynamic
        # loss scaling is enabled.
        self._incr_every_n_steps = incr_every_n_steps
        self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self._incr_ratio = incr_ratio
        self._decr_ratio = decr_ratio
        self._num_good_steps = None
        self._num_bad_steps = None
        self.use_promote = use_promote

    def _set_distributed(self, flag):
        # if distributed, all cards will communication with each other,
        # overlap communication and computation by split the
        # check_finite_and_unscale op.
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.
        It's useful when you feed a customized loss into the executor.
        Returns None until backward()/minimize() has been called.
        """
        return self._scaled_loss

    def _supports_check_nan_inf(self):
        """Whether the wrapped optimizer unscales and checks nan/inf itself.

        Reads the optional ``_supports_check_nan_inf`` attribute of the wrapped
        optimizer, defaulting to False when it is absent.
        """
        return getattr(self._optimizer, "_supports_check_nan_inf", False)

    def _init_amp_var(self):
        """Create the persistable AMP variables in the current program.

        Always creates the float32 ``loss_scaling`` variable; creates the int32
        ``num_good_steps`` / ``num_bad_steps`` counters only when dynamic loss
        scaling is enabled; and replaces a Python-float learning rate with a
        float32 global variable in the optimizer's learning-rate map.
        """
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = paddle.static.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        The training program is first rewritten for AMP (level O1 or O2), the
        loss is cast to float32 if needed and multiplied by the loss scaling
        factor (unless scaling is statically 1.0), and the wrapped optimizer's
        ``backward`` is then run on the scaled loss.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables should be ignored.
            callbacks (list|None): A list of callable objects to run when appending
                                   backward operator for one parameter.

        Returns:
            A list of (param, grad), which is a tuple of a parameter and its
            gradient respectively. The scaled loss is not returned; use
            `get_scaled_loss()` to retrieve it.
        """
        train_program = loss.block.program
        self._train_program = train_program
        self._float_status = None

        with program_guard(self._train_program, startup_program):
            self._init_amp_var()

            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            else:
                # use_fp16_guard is not supported in amp-o1.
                cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level='O1',
                    use_promote=self.use_promote,
                )

            # The loss is scaled and back-propagated in float32.
            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
            # When not using dynamic loss scaling and the init loss scaling
            # value is equal to 1.0, scaling is a no-op and is skipped.
            if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
                self._scaled_loss = loss * self._loss_scaling
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
            if self._supports_check_nan_inf():
                self._add_cast_ops_to_startup_program(startup_program)
        return params_grads

    def _add_cast_ops_to_startup_program(self, startup_program):
        """Append startup ops that cast O2 parameters to the AMP dtype.

        For every variable name recorded by ``cast_model_to_fp16`` that is
        also a parameter of the startup program's global block, an ``assign``
        into a float32 temporary followed by a ``cast`` (float32 -> AMP dtype)
        back into the parameter is appended. Names are processed in sorted
        order so the appended ops are deterministic. The recorded names are
        cleared afterwards.

        Args:
            startup_program (Program|None): The startup program to append to;
                ``default_startup_program()`` is used when None.
        """
        names = (
            sorted(self._to_fp16_var_names) if self._to_fp16_var_names else []
        )
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        # A set gives O(1) membership tests inside the loop below.
        param_names = {p.name for p in block.all_parameters()}
        for name in names:
            if name not in param_names:
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_vartype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."
        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_vartype,
            )
        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_vartype,
                    level='O2',
                    use_promote=self.use_promote,
                )
            elif use_fp16_test:
                # use_fp16_guard is not supported in amp-o1.
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    use_fp16_guard=False,
                    dest_type=self._amp_vartype,
                    level='O1',
                    use_promote=self.use_promote,
                )

    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """

        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, params_grads)

        # When not using dynamic loss scaling and the init loss scaling value
        # is equal to 1.0, the gradients were never scaled, so the wrapped
        # optimizer can apply them directly.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
            # The wrapped optimizer unscales and checks nan/inf by itself.
            self._optimizer._set_scale(self._loss_scaling)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
            # Only a dynamic scale is ever updated. With static loss scaling
            # the step counters are never created by _init_amp_var(), so
            # updating the scale here would be both wrong and fail.
            if self._use_dynamic_loss_scaling:
                found_inf = self._optimizer._found_inf
                self._add_dynamic_loss_scaling(params_grads, found_inf)
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_vartype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
        # With fleet, optimizers are nested and the real optimizer set by user is the inner most one.
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
            # copy it in advance to avoid multiple time copies.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops

    def _split_grads(self, params_grads):
        """Split gradients by dtype.

        Returns:
            tuple: ``(grads, fp32_grads, fp16_grads)`` where ``grads`` holds every
            gradient, ``fp32_grads`` the float32 ones and ``fp16_grads`` those
            whose dtype equals the AMP dtype (fp16 or bf16).

        Raises:
            AssertionError: If any gradient has another dtype.
        """
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads

    def _check_finite_and_unscale(self, params_grads):
        """Append ops that unscale the gradients and detect nan/inf.

        In distributed mode one check op is emitted per (param, grad) so that
        unscaling can overlap with communication; in pure fp16/bf16 mode the
        float32 and AMP-dtype gradients are checked by separate ops; otherwise
        a single op covers all gradients. In the first two modes the partial
        flags are concatenated and reduced with ``paddle.any``.

        Returns:
            The found-inf variable produced by the appended ops.
        """
        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        found_infs = []

        if self._is_distributed:
            # if distributed, split check_finite_and_unscale to overlap
            # unscale with communication
            for p, g in params_grads:
                with self._train_program._optimized_guard([p, g]):
                    _, found_inf = check_finite_and_unscale(
                        [
                            g,
                        ],
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                    found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
                with self._train_program._optimized_guard(fp32_grads):
                    _, fp32_found_inf = check_finite_and_unscale(
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
                    _, fp16_found_inf = check_finite_and_unscale(
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
                _, found_inf = check_finite_and_unscale(
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            # Reduce the partial flags into a single found_inf.
            with self._train_program._optimized_guard([]):
                all_infs = paddle.concat(found_infs)
                found_inf = paddle.any(all_infs)

        return found_inf

    def _add_dynamic_loss_scaling(self, params_grads, found_inf):
        """Append ``update_loss_scaling`` op(s) driven by ``found_inf``.

        Requires the counters created by ``_init_amp_var`` when dynamic loss
        scaling is enabled.
        """
        if self._supports_check_nan_inf():
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    [],
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
        if self._use_pure_fp16:
            # After the first op, stop_update is True for the second one,
            # presumably so the shared scale/counters are updated only once
            # per step -- TODO confirm against update_loss_scaling.
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )

    def apply_optimize(self, loss, startup_program, params_grads):
        """Run ``apply_gradients`` under the loss's program and ``startup_program``."""
        program = loss.block.program
        with program_guard(program, startup_program):
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables should be ignored.

        Returns:
            tuple: ``(optimize_ops, scaled_params_grads)`` -- the list of
            optimize operators and the list of (param, scaled grad) pairs.
            The scaled loss is available through `get_scaled_loss()`.
        """

        # The wrapped optimizer's own `minimize` is bypassed; warn when the
        # optimizer class defines one directly.
        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads
618 619


620
@overload(key=FunctionType.FP16_ONLY)
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object. When None, a default
                                   AutoMixedPrecisionLists for the selected
                                   dtype is created.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training (AMP level
                           O2 instead of O1). Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
                           Default None, which means that its value equals to `use_pure_fp16`.
        use_bf16(bool): Whether to enable bfloat16 training. Default False.
        use_promote(bool): Whether to promote to fp32 when an op has any
                           float32 inputs. Default False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
            .. code-block:: python

            # black&white list based strategy example
            import paddle
            import paddle.static as static

            paddle.enable_static()

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = static.amp.decorate(
                    optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()

    Examples 2:
        .. code-block:: python

            # pure fp16 training example
            import numpy as np
            import paddle
            import paddle.nn.functional as F

            def run_example_code():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                # 1) Use fp16_guard to control the range of fp16 kernels used.
                with paddle.static.amp.fp16_guard():
                    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                    hidden = paddle.static.nn.fc(pool, size=10)
                    loss = paddle.mean(hidden)
                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True can avoid the poor accuracy
                # or the slow convergence in a way.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
                    custom_black_list=['pool2d'])
                # 4) The entry of Paddle AMP.
                # Enable pure fp16 training by setting `use_pure_fp16` to True.
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=128.0,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=True)
                # If you don't use the default_startup_program(), you should pass
                # your defined `startup_program` into `minimize`.
                optimizer.minimize(loss)
                exe.run(paddle.static.default_startup_program())
                # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    amp_dtype = "bfloat16" if use_bf16 else "float16"

    # Fall back to the default white/black op lists for the chosen dtype.
    amp_lists = (
        AutoMixedPrecisionLists(dtype=amp_dtype)
        if amp_lists is None
        else amp_lists
    )

    # An unspecified guard setting follows `use_pure_fp16`.
    guard = use_pure_fp16 if use_fp16_guard is None else use_fp16_guard

    # The legacy boolean switch maps onto the AMP level: pure fp16 is O2.
    amp_level = "O2" if use_pure_fp16 else "O1"

    return OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=amp_level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=guard,
        use_promote=use_promote,
    )


752 753
@overload(key=FunctionType.COMMON)
def decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    master_weight=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=None,
    use_amp_guard=False,
    use_promote=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer. If it is None, the
            master-weight (multi-precision) setting below is not applied.
        amp_lists(AutoMixedPrecisionLists, optional): An AutoMixedPrecisionLists
            object. When it is not set, one is built for the chosen ``dtype``,
            so the default white_list and black_list are used for AMP
            training. Default is None.
        level(str, optional): Auto mixed precision level, case-insensitive.
            Accepted values are "O0", "O1" and "O2": O0 represents FP32
            training; O1 represents mixed precision, where the input data type
            of each operator is cast according to white_list and black_list;
            O2 represents pure FP16 / BF16 training, where all operator
            parameters and input data are cast to FP16 / BF16, except
            operators in black_list, operators without an FP16 / BF16 kernel,
            and batch_norm. Default is O1.
        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. The
            value is validated by ``check_amp_dtype``. Default is 'float16'.
        master_weight(bool, optional): Whether the optimizer should use
            multi-precision (master weights) during weight updating. Unless
            master_weight is explicitly False, the optimizer's multi-precision
            flag is turned on (when the optimizer supports it); False leaves
            the optimizer's own setting untouched. Default is None.
        init_loss_scaling(float, optional): The initial loss scaling factor.
            Default is 32768.
        incr_every_n_steps(int, optional): Increases loss scaling every n
            consecutive steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
            accumulated steps with nan or inf gradients. Default is 2.
        incr_ratio(float, optional): The multiplier to use when increasing the
            loss scaling. Default is 2.
        decr_ratio(float, optional): The less-than-one-multiplier to use when
            decreasing the loss scaling. Default is 0.8.
        use_dynamic_loss_scaling(bool, None): Whether to use dynamic loss
            scaling. Default is None, which means True for float16, and False
            for bfloat16.
        use_amp_guard(bool, optional): Forwarded unchanged to
            ``OptimizerWithMixedPrecision``; see that class for its
            semantics. Default is False.
        use_promote(bool, optional): Forwarded unchanged to
            ``OptimizerWithMixedPrecision``; see that class for its
            semantics. Default is False.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training

    Raises:
        ValueError: If ``level`` is not one of "O0", "O1" or "O2"
            (``check_amp_dtype`` may also reject an unsupported ``dtype``).
        RuntimeError: If ``optimizer`` is not None and is not a
            ``paddle.optimizer.Optimizer`` or ``paddle.fluid.optimizer.Optimizer``.

    Examples:

     .. code-block:: python

        import paddle

        paddle.enable_static()

        class SimpleConvNet(paddle.nn.Layer):
            def __init__(self):
                super().__init__()
                self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
                self.linear = paddle.nn.Linear(in_features=26, out_features=10)

            def forward(self, x):
                out = self.conv(x)
                out = paddle.nn.functional.relu(out)
                out = self.linear(out)
                out = paddle.nn.functional.softmax(out)
                return out

        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.utils.unique_name.guard():
            with paddle.static.program_guard(main_program, startup_program):
                model = SimpleConvNet()
                x = paddle.static.data(
                    name='input', shape=[None, 1, 28, 28], dtype='float32'
                )
                out = model(x)
                loss = paddle.mean(out)
                optimizer = paddle.optimizer.AdamW()
                optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
                optimizer.minimize(loss)

        if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
            place = paddle.CUDAPlace(0)
            exe = paddle.static.Executor(place)
            exe.run(startup_program)

            # Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
            # to convert FP32 parameters to low precision FP16 / BF16.
            optimizer.amp_init(place, scope=paddle.static.global_scope())

    """
    # Check the AMP level (case-insensitive): O0-O2.
    level = level.upper()
    if level not in ('O0', 'O1', 'O2'):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    amp_dtype = check_amp_dtype(dtype)
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if use_dynamic_loss_scaling is None:
        # Decide on the validated dtype, which is the one actually forwarded
        # to OptimizerWithMixedPrecision below, rather than the raw argument.
        # That way, any normalization done by check_amp_dtype is honored.
        use_dynamic_loss_scaling = amp_dtype == "float16"

    if optimizer is not None:
        # Support master_weight: only an explicit False opts out of
        # enabling multi-precision on the wrapped optimizer.
        multi_precision = master_weight is not False
        _set_multi_precision(optimizer, multi_precision)

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        level=level,
        dtype=amp_dtype,
        init_loss_scaling=init_loss_scaling,
        use_dynamic_loss_scaling=use_dynamic_loss_scaling,
        incr_every_n_steps=incr_every_n_steps,
        decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
        use_promote=use_promote,
    )

    return mp_optimizer