optimizer.py 60.4 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import defaultdict

18 19
import numpy as np

20
import paddle
21 22
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
23 24
from paddle.fluid.framework import (
    Variable,
25 26 27
    _current_expected_place,
    _in_eager_without_dygraph_check,
    _in_legacy_dygraph,
28 29
    default_main_program,
    device_guard,
30
    in_dygraph_mode,
31 32
    name_scope,
)
M
MRXLT 已提交
33

34
from ..fluid import framework, layers, unique_name
35
from ..fluid.backward import _get_no_grad_set_name, append_backward
36 37 38 39 40
from ..fluid.clip import (
    GradientClipBase,
    append_gradient_clip_ops,
    error_clip_callback,
)
41 42
from ..fluid.dygraph import base as imperative_base
from ..fluid.framework import Parameter, program_guard
M
MRXLT 已提交
43 44
from ..fluid.initializer import Constant
from ..fluid.layer_helper import LayerHelper
45
from .lr import LRScheduler
M
MRXLT 已提交
46

47 48
__all__ = []

M
MRXLT 已提交
49

50
@framework.static_only
def append_backward_new(
    loss_list,
    parameter_list=None,
    no_grad_set=None,
    callbacks=None,
    checkpoints=None,
    distop_context=None,
):
    """Append backward ops for ``loss_list`` with the primitive-op autodiff transform.

    The current block is lowered to primitive ops (``orig2prim``), then
    linearized and transposed with ``Transform`` to obtain the gradient of the
    losses with respect to every parameter in ``parameter_list``.

    Args:
        loss_list (list[Variable]): Loss variables. Each must live in the
            current block of the default main program.
        parameter_list (list[Parameter], optional): Parameters to
            differentiate with respect to. Defaults to all parameters of the
            global block.
        no_grad_set, callbacks, checkpoints, distop_context: Accepted but not
            used by this implementation.

    Returns:
        list[tuple]: ``(param, grad)`` pairs, one per entry of
        ``parameter_list``.
    """
    from paddle.incubate.autograd.primx import Transform, orig2prim

    program = default_main_program()
    assert (
        program.num_blocks == 1
    ), "The append_backward_new interface is designed to process only one block."
    block = program.current_block()
    for el in loss_list:
        assert (
            el.block == block
        ), 'variable in loss_list should be in current block of main program'

    orig2prim(block)
    ad = Transform(block)
    if parameter_list is None:
        parameter_list = program.global_block().all_parameters()
    param_dot, loss_dot = ad.linearize(parameter_list, loss_list)
    _, param_bar = ad.transpose(loss_dot, param_dot)

    # Remove param_dot and the ops that construct them.
    # (list.index raises ValueError if the op is missing, so no extra check.)
    op_indexes = [
        block.ops.index(var.op) for var in param_dot if var is not None
    ]
    ad.erase_ops(sorted(op_indexes))
    ad.erase_dots(param_dot)

    # Pair each parameter with its own gradient. A lone gradient variable is
    # wrapped so the single-parameter case still yields ``(param, grad)``
    # instead of pairing the whole parameter list with the gradient.
    if isinstance(param_bar, (list, tuple)):
        grads = list(param_bar)
    else:
        grads = [param_bar]
    return [(param, grads[i]) for i, param in enumerate(parameter_list)]


98
class Optimizer:
99
    r"""Optimizer Base class.
M
MRXLT 已提交
100 101 102 103 104 105

    Define the common interface of an optimizer.
    User should not use this class directly,
    but need to use one of it's implementation.

    Args:
106 107
        learning_rate (float|LRScheduler): The learning rate used to update ``Parameter``.
            It can be a float value or any subclass of ``LRScheduler`` .
108
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
109 110 111 112
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in paramter groups \
            represents the scale of base learning_rate. \
M
MRXLT 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It canbe a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of \
            some derived class of ``GradientClipBase`` . There are three cliping strategies \
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , \
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Returns:
130 131
       Base class for optimizer.

M
MRXLT 已提交
132 133 134 135 136 137
    Examples:
        .. code-block:: python

            #Take the subclass adam as an example
            import paddle
            linear = paddle.nn.Linear(10, 10)
138
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
M
MRXLT 已提交
139 140 141 142
            out = linear(inp)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(learning_rate=0.1,
                    parameters=linear.parameters())
R
Roc 已提交
143
            loss.backward()
M
MRXLT 已提交
144 145 146
            adam.step()
            adam.clear_grad()

147
            #Take the subclass sgd as an example
148
            #optimize parameters in linear_1 and linear2 in different options.
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            sgd = paddle.optimizer.SGD(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1
                }],
165
                weight_decay=0.01)
R
Roc 已提交
166
            loss.backward()
167 168 169
            sgd.step()
            sgd.clear_grad()

M
MRXLT 已提交
170 171
    """

172
    @imperative_base.no_grad
    def __init__(
        self,
        learning_rate,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
        """Validate the common optimizer arguments and initialise shared state.

        Args:
            learning_rate (float|LRScheduler): Base learning rate.
            parameters (list|tuple|None): Tensors to optimise, or a list of
                dicts (parameter groups) that each contain a ``'params'`` key.
                Required in dygraph mode.
            weight_decay (float|WeightDecayRegularizer|None): A float is
                wrapped in ``L2Decay``. Any other value is stored unchanged.
            grad_clip (GradientClipBase|None): Gradient clipping strategy.
            name (str|None): Prefix prepended to accumulator names.

        Raises:
            TypeError: if ``parameters`` is a single Tensor or a dict, if
                ``learning_rate`` is not a float/LRScheduler, or if
                ``grad_clip`` is not a ``GradientClipBase`` instance.
            AttributeError: if ``parameters`` is None in dygraph mode.
        """
        if parameters is not None:
            # paddle.Tensor is also iterable, so here we don't check whether
            # the input is iterable, if the input is paddle.Tensor, the
            # list(paddle.Tensor) will be a error value
            if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
                raise TypeError(
                    "`parameters` argument given to the optimizer should be "
                    "an iterable of paddle Tensors, but got argument type is `{}`.".format(
                        type(parameters)
                    )
                )
            if isinstance(parameters, dict):
                raise TypeError(
                    "`parameters` argument should not get dict type, "
                    "if parameter groups is needed, please set `parameters`"
                    " as list of dict"
                )
            self._parameter_list = list(parameters)
        else:
            self._parameter_list = None

        self._name = name
        if framework._non_static_mode():
            if self._parameter_list is None:
                raise AttributeError(
                    "parameters argument given to the Optimizer should not be None in dygraph mode."
                )
            # Guard on a non-empty list: an empty ``parameters`` has no first
            # element to inspect and previously crashed with IndexError.
            if (
                weight_decay is not None
                and self._parameter_list
                and not isinstance(self._parameter_list[0], dict)
            ):
                for param in self._parameter_list:
                    if (
                        hasattr(param, 'regularizer')
                        and param.regularizer is not None
                    ):
                        logging.info(
                            "If regularizer of a Parameter has been set by 'paddle.ParamAttr' or 'static.WeightNormParamAttr' already. "
                            "The weight_decay[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
                            % weight_decay.__str__()
                        )
                        break

        if not isinstance(learning_rate, (float, LRScheduler)):
            raise TypeError(
                "learning rate should be float or LRScheduler, got %s here"
                % type(learning_rate)
            )
        if grad_clip is not None:
            if not isinstance(grad_clip, GradientClipBase):
                raise TypeError(
                    "'grad_clip' should be an instance of GradientClipBase's derived class"
                )
        if isinstance(weight_decay, float):
            from ..fluid.regularizer import L2Decay

            self.regularization = L2Decay(weight_decay)
        else:
            self.regularization = weight_decay
        self._grad_clip = grad_clip
        self._learning_rate = learning_rate

        self._dtype = None
        # Infer the dtype from the first parameter (or the first parameter of
        # the first group when parameter groups are used).
        if self._parameter_list:
            if isinstance(self._parameter_list[0], dict):
                for param_group in self._parameter_list:
                    assert (
                        'params' in param_group
                    ), 'params should be set in parameters if parameter groups are optimized in different options'
                self._dtype = self._parameter_list[0]['params'][0].dtype
            else:
                self._dtype = self._parameter_list[0].dtype

        # each program should have a independent learning rate
        # program -> tensor(learning_rate)
        self._learning_rate_map = dict()
        # Dictionary of accumulators. Some optimizer subclasses need to
        # allocate and manage extra tensors associated with the parameters
        # to train. These tensors are called accumulators.
        # {accum_name : { parameter_name : accumulator_for_parameter, ...}, ...}
        self._accumulators = defaultdict(lambda: dict())
        self.helper = None
        self._opti_name_list = []
        self._accumulators_holder = {}
        self._param_device_map = dict()
        self.clear_gradients = self.clear_grad
        self._default_dict = {
            'weight_decay': self.regularization,
            'grad_clip': self._grad_clip,
        }

        self._param_groups = []
        if self._parameter_list and isinstance(self._parameter_list[0], dict):
            for param_group in self._parameter_list:
                self._add_param_group(param_group.copy())
        else:
            self._param_groups = self._parameter_list

        # NOTE: Multi Tensor: Pass in all parameters and gradients to the op kernel of the Optimizer at one time for updating for dygraph mode.
        # Optimizer support list: [ paddle.optimizer.Momentum, paddle.optimizer.Adam].
        self._use_multi_tensor = None

        self._param_dict = self._create_multi_tensor_dict()
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

289 290 291 292 293 294 295
    def _create_multi_tensor_dict(self):
        n = len(self._param_groups) if self._param_groups is not None else 1
        return {
            'FP32_LODTensor': [[] for _ in range(n)],
            'FP16_LODTensor': [[] for _ in range(n)],
        }

296 297 298
    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

M
MRXLT 已提交
299 300 301
    @framework.dygraph_only
    def state_dict(self):
        '''
        Collect the optimizer's state as a dict. Every accumulator tensor is
        stored under its own tensor name. Master weights are included (key
        ``"master_weights"``) when the optimizer has a non-empty
        ``_master_weights``. The scheduler state is included (key
        ``"LR_Scheduler"``) when an LRScheduler drives the learning rate.

        If the optimizer has not created any accumulators yet (e.g. ``minimize``
        was never called), no accumulator entries are present.

        Args:
            None

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import paddle
                emb = paddle.nn.Embedding(10, 10)

                adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters())
                state_dict = adam.state_dict()

        '''
        collected = {
            accumulator.name: accumulator
            for per_param in self._accumulators.values()
            for accumulator in per_param.values()
        }
        # if has master weight and then save master weight
        if hasattr(self, "_master_weights") and len(self._master_weights) != 0:
            collected["master_weights"] = self._master_weights
        # global step if use lr decay
        if isinstance(self._learning_rate, LRScheduler):
            collected["LR_Scheduler"] = self._learning_rate.state_dict()
        return collected

    @framework.dygraph_only
    def set_state_dict(self, state_dict):
        '''
        Load optimizer state (accumulators such as Adam's moments, master
        weights, and the LRScheduler state) from ``state_dict``.

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer
        Return:
            None

        Raises:
            KeyError: if an LRScheduler is used but ``"LR_Scheduler"`` is
                missing from ``state_dict``.
            AssertionError: if an accumulator is missing from ``state_dict`` or
                its shape/dtype does not match.
            RuntimeError: if a stored value is not a Tensor or numpy array.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                layer_state_dict = emb.state_dict()
                paddle.save(layer_state_dict, "emb.pdparams")

                scheduler = paddle.optimizer.lr.NoamDecay(
                    d_model=0.01, warmup_steps=100, verbose=True)
                adam = paddle.optimizer.Adam(
                    learning_rate=scheduler,
                    parameters=emb.parameters())
                opt_state_dict = adam.state_dict()
                paddle.save(opt_state_dict, "adam.pdopt")

                opti_state_dict = paddle.load("adam.pdopt")
                adam.set_state_dict(opti_state_dict)

        '''
        if isinstance(self._learning_rate, LRScheduler):
            self._learning_rate.set_state_dict(state_dict["LR_Scheduler"])

        # NOTE: exclude learning rate scheduler's state from
        # _accumulators_holder. Copy first so the caller's dict is untouched.
        state_dict = state_dict.copy()
        if "LR_Scheduler" in state_dict:
            state_dict.pop("LR_Scheduler")
        if "master_weights" in state_dict:
            if hasattr(self, "_master_weights"):
                self._master_weights = state_dict["master_weights"]
            state_dict.pop("master_weights")
        # Kept so accumulators created later (in _add_accumulator) can be
        # initialised from the loaded values.
        self._accumulators_holder = state_dict
        for per_param in self._accumulators.values():
            for var_tmp in per_param.values():
                assert (
                    var_tmp.name in state_dict
                ), "optimizer Tensor {} not found".format(var_tmp.name)
                var = var_tmp.value()
                tensor = var.get_tensor()
                model_np = np.array(tensor)

                load_para = state_dict[var_tmp.name]

                if isinstance(load_para, Variable):
                    load_para_np = load_para.numpy()
                elif isinstance(load_para, core.VarBase):
                    load_para_np = load_para.numpy()
                elif isinstance(load_para, np.ndarray):
                    load_para_np = load_para
                else:
                    raise RuntimeError(
                        "State dict type {} not supprt".format(
                            str(type(load_para))
                        )
                    )

                # Report the accumulator's tensor name: numpy arrays have no
                # ``.name`` attribute, so formatting with ``model_np.name``
                # raised AttributeError instead of this message.
                assert (
                    model_np.shape == load_para_np.shape
                ), "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format(
                    var_tmp.name, model_np.shape, load_para_np.shape
                )

                assert (
                    model_np.dtype == load_para_np.dtype
                ), "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {}  but load tensor with dtype {}".format(
                    var_tmp.name, model_np.dtype, load_para_np.dtype
                )

                tensor.set(load_para_np, framework._current_expected_place())

    def get_opti_var_name_list(self):
        return self._opti_name_list

    def _create_global_learning_rate(self):
        """Ensure the default main program has a learning-rate variable.

        For an ``LRScheduler``, the variable is created once per program and
        registered on the program (``lr_sheduler`` / ``lr_var``). Its
        initializer is reset to the scheduler's current value on every call.
        For a float learning rate, the variable is created once and later
        calls return early.
        """
        # lr var can't be float16 or bfloat16, for pure fp16 or bf16 training,
        # should extra handle the dtype for lr: fall back to float32 unless the
        # global default dtype is itself that low-precision type.
        default_dtype = paddle.get_default_dtype()
        lr_dtype = default_dtype if self._dtype is None else self._dtype
        demote_fp16 = default_dtype != "float16" and lr_dtype == paddle.float16
        demote_bf16 = (
            default_dtype != "bfloat16" and lr_dtype == paddle.bfloat16
        )
        if demote_fp16 or demote_bf16:
            lr_dtype = paddle.float32

        if isinstance(self._learning_rate, LRScheduler):
            lr_var = self._global_learning_rate()
            # only create global lr_var once
            if not isinstance(lr_var, framework.Variable):
                lr_name = unique_name.generate('learning_rate')
                self._learning_rate._var_name = lr_name
                lr_var = self.helper.create_global_variable(
                    name=lr_name,
                    shape=[1],
                    persistable=True,
                    stop_gradient=True,
                    dtype=lr_dtype,
                )
                program = framework.default_main_program()
                program.lr_sheduler = self._learning_rate
                program.lr_var = lr_var
                self._learning_rate_map[program] = lr_var

            # Re-initialise with the scheduler's current value each time.
            self.helper.set_variable_initializer(
                lr_var,
                initializer=Constant(value=float(self._learning_rate())),
            )
            return

        if not isinstance(self._learning_rate, float):
            return
        # only create global lr_var once
        if isinstance(self._global_learning_rate(), framework.Variable):
            return
        new_lr_var = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(self._learning_rate),
            dtype=lr_dtype,
            persistable=True,
        )
        self._learning_rate_map[framework.default_main_program()] = new_lr_var
M
MRXLT 已提交
479 480 481 482 483

    @framework.dygraph_only
    def set_lr(self, value):
        """
        :api_attr: imperative

        Overwrite the optimizer's learning rate with a fixed value. This API
        cannot be used when the optimizer is driven by an ``LRScheduler``,
        because the scheduler would conflict with the manual value.

        Args:
            value (float): the value of learning rate

        Returns:
            None

        Raises:
            TypeError: if ``value`` is not an int or float.
            RuntimeError: if the optimizer uses an ``LRScheduler``.

        Examples:
            .. code-block:: python

                import paddle
                linear = paddle.nn.Linear(10, 10)

                adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

                # set learning rate manually by python float value
                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6

        """
        if not isinstance(value, (int, float)):
            raise TypeError(
                "The type of 'value' in optimizer.set_lr must be float, but received %s."
                % (type(value))
            )
        if isinstance(self._learning_rate, LRScheduler):
            raise RuntimeError(
                "optimizer's learning rate can't be LRScheduler when invoke this API, because this will lead to conflict."
            )
        new_lr = float(value)
        self._learning_rate = new_lr

        lr_var = self._global_learning_rate()
        if lr_var is None:
            # No learning-rate variable exists yet; it will be created later
            # from the float stored above.
            return

        lr_shape = list(lr_var.shape)
        if in_dygraph_mode():
            _C_ops.full_(
                lr_var,
                lr_shape,
                new_lr,
                lr_var.dtype,
                _current_expected_place(),
            )
        elif _in_legacy_dygraph():
            _legacy_C_ops.fill_constant(
                lr_var,
                'value',
                new_lr,
                'dtype',
                lr_var.dtype,
                'shape',
                lr_shape,
            )
        else:
            framework.default_main_program().global_block().append_op(
                type='fill_constant',
                outputs={'Out': [lr_var]},
                attrs={
                    'dtype': lr_var.dtype,
                    'shape': lr_shape,
                    'value': new_lr,
                },
                stop_gradient=True,
            )
M
MRXLT 已提交
560 561 562

    def get_lr(self):
        """
563
        Get current learning rate of optimizer.
564 565
        If 'LRScheduler' is not used, the return value is all the same.
        If 'LRScheduler' is used, the return value is the current scheduled learing rete.
M
MRXLT 已提交
566

M
MRXLT 已提交
567
        Returns:
568
            float: The current learning rate of optimizer.
M
MRXLT 已提交
569 570 571 572

        Examples:
            .. code-block:: python

573
                # train on default dynamic graph mode
M
MRXLT 已提交
574
                import paddle
575 576 577 578 579 580 581 582 583 584 585
                import numpy as np
                emb = paddle.nn.Embedding(10, 3)

                ## example1: LRScheduler is not used, return the same value is all the same
                adam = paddle.optimizer.Adam(0.01, parameters = emb.parameters())
                for batch in range(10):
                    input = paddle.randint(low=0, high=5, shape=[5])
                    out = emb(input)
                    out.backward()
                    print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.01
                    adam.step()
M
MRXLT 已提交
586

587 588 589 590 591 592 593 594
                ## example2: StepDecay is used, return the scheduled learning rate
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
                adam = paddle.optimizer.Adam(scheduler, parameters = emb.parameters())
                for batch in range(10):
                    input = paddle.randint(low=0, high=5, shape=[5])
                    out = emb(input)
                    out.backward()
                    print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05...
M
MRXLT 已提交
595
                    adam.step()
596
                    scheduler.step()
597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615

                # train on static graph mode
                paddle.enable_static()
                main_prog = paddle.static.Program()
                start_prog = paddle.static.Program()
                with paddle.static.program_guard(main_prog, start_prog):
                    x = paddle.static.data(name='x', shape=[None, 10])
                    z = paddle.static.nn.fc(x, 100)
                    loss = paddle.mean(z)
                    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
                    adam = paddle.optimizer.Adam(learning_rate=scheduler)
                    adam.minimize(loss)

                exe = paddle.static.Executor()
                exe.run(start_prog)
                for batch in range(10):
                    print("Learning rate of step{}: {}", adam.get_lr())     # 0.5->0.05->0.005...
                    out = exe.run(main_prog, feed={'x': np.random.randn(3, 10).astype('float32')})
                    scheduler.step()
M
MRXLT 已提交
616 617 618 619 620

        """
        if isinstance(self._learning_rate, float):
            return self._learning_rate
        else:
621
            return self._learning_rate()
M
MRXLT 已提交
622 623 624 625 626 627 628 629 630 631 632

    def _global_learning_rate(self, program=None):
        """
        get global decayed learning rate
        :return:
        """
        if program is None:
            program = framework.default_main_program()
        return self._learning_rate_map.get(program, None)

    def _append_optimize_op(self, block, param_and_grad):
633
        """append optimize operator to block and return all the added optimize_op"""
M
MRXLT 已提交
634 635 636 637 638 639 640
        raise NotImplementedError(
            "Class \"Optimizer\" connot be used directly as an optimizer, please use its subclasses such as \"Adam\""
        )

    def _create_param_lr(self, param_and_grad):
        """Return the learning rate to use for a single parameter.

        If the parameter carries ``optimize_attr['learning_rate']``:

        * a Variable is returned unchanged;
        * ``1.0`` means "use the global learning rate";
        * any other value scales the global learning rate, built under the
          program's lr-schedule guard.

        Without ``optimize_attr`` the global learning rate is returned.

        Args:
            param_and_grad (tuple): ``(param, grad)`` pair; only ``param`` is read.
        """
        param = param_and_grad[0]
        if not hasattr(param, 'optimize_attr'):
            return self._global_learning_rate()

        param_lr = param.optimize_attr['learning_rate']
        # isinstance (not ``type(...) ==``) so Variable subclasses and eager
        # Tensors, which Variable's isinstance hook recognises, are returned
        # as-is. Otherwise they would fall through to ``param_lr == 1.0``,
        # a tensor comparison.
        if isinstance(param_lr, Variable):
            return param_lr
        if param_lr == 1.0:
            return self._global_learning_rate()
        with default_main_program()._lr_schedule_guard(
            is_with_opt=True
        ), framework.name_scope('scale_with_param_lr'):
            return self._global_learning_rate() * param_lr
M
MRXLT 已提交
655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677

    def _create_accumulators(self, block, parameters):
        """Create all accumulators needed by the parameters

        Args:
            block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
        """
        pass

    def _finish_update(self, block, parameters_and_grads):
        """Finish any custom updates needed
           before completing an optimization step

        Args:
            block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer

        Returns:
            None
        """
        pass

678 679 680 681 682 683 684 685 686 687
    def _add_accumulator(
        self,
        name,
        param,
        dtype=None,
        fill_value=0.0,
        shape=None,
        type=None,
        device=None,
    ):
        """Create an accumulator tensor for ``param`` and register it.

        Args:
            name (str): Accumulator name; prefixed with ``self._name + "_"``
                when the optimizer was given a name.
            param: Parameter the accumulator belongs to; its ``name``,
                ``shape``, ``dtype`` and ``type`` provide the defaults.
            dtype: Accumulator dtype; defaults to ``param.dtype``.
            fill_value (float): Initial value for every element.
            shape: Accumulator shape; defaults to ``param.shape``.
            type: Variable type; defaults to ``param.type``. In eager mode
                ``LOD_TENSOR`` is always used.
            device: Device used for initialisation; defaults to
                ``self._get_device_for_param(param.name)``.

        Returns:
            The accumulator variable, also stored in
            ``self._accumulators[name][param.name]``. In non-static mode an
            existing accumulator is returned instead of creating a new one.

        Raises:
            Exception: if the accumulator already exists in static mode.
        """
        if self._name is not None:
            name = self._name + "_" + name

        # Short-circuit so the defaultdict does not gain an empty entry.
        exists = (
            name in self._accumulators
            and param.name in self._accumulators[name]
        )
        if exists:
            if framework._non_static_mode():
                return self._accumulators[name][param.name]
            raise Exception(
                "Accumulator {} already exists for parameter {}".format(
                    name, param.name
                )
            )

        acc_shape = param.shape if shape is None else shape
        assert isinstance(self.helper, LayerHelper)

        var_name = unique_name.generate(param.name + "_" + name)
        self._opti_name_list.append(var_name)

        if framework._in_eager_without_dygraph_check():
            var_type = core.VarDesc.VarType.LOD_TENSOR
        else:
            var_type = param.type if type is None else type

        var = self.helper.create_global_variable(
            name=var_name,
            persistable=True,
            dtype=dtype or param.dtype,
            type=var_type,
            shape=acc_shape,
            belong_to_optimizer=True,
        )

        target_device = (
            self._get_device_for_param(param.name) if device is None else device
        )

        if in_dygraph_mode() and (
            target_device == 'cpu' or isinstance(target_device, core.CPUPlace)
        ):
            # Eager mode on CPU: fill the tensor in place directly.
            _C_ops.full_(
                var,
                var.shape,
                str(float(fill_value)),
                var.dtype,
                core.CPUPlace(),
            )
        else:
            with device_guard(target_device):
                self.helper.set_variable_initializer(
                    var, initializer=Constant(value=float(fill_value))
                )

        # State loaded by set_state_dict is applied to accumulators created
        # after the load.
        if framework._non_static_mode() and len(self._accumulators_holder) > 0:
            assert (
                var_name in self._accumulators_holder
            ), "Optimizer set error, {} should in state dict".format(var_name)
            var.set_value(self._accumulators_holder[var_name])

        self._accumulators[name][param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        Args:
            name: name of the accumulator
            param: parameter tensor for which accumulator is to be fetched

        Returns:
            accumulator tensor for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
771 772 773 774
        if (
            name not in self._accumulators
            or param.name not in self._accumulators[name]
        ):
775 776
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
777 778 779
                    name, param.name
                )
            )
M
MRXLT 已提交
780 781 782 783
        return self._accumulators[name][param.name]

    def _update_param_device_map(self, parameters_and_grads, target_block):
        """Record the device of the first op in ``target_block`` reading each param.

        For every pair whose parameter has ``stop_gradient is False``, scan
        ``target_block.ops`` in order. Store the op-device attribute of the
        first op whose ``input_arg_names`` contains the parameter's name into
        ``self._param_device_map``. Parameters not read by any op are left
        untouched.

        Args:
            parameters_and_grads: iterable of ``(param, grad)`` pairs.
            target_block: block whose ops are scanned.
        """
        device_attr = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        block_ops = target_block.ops
        for pair in parameters_and_grads:
            param = pair[0]
            if param.stop_gradient is not False:
                continue
            pname = param.name
            first_reader = next(
                (op for op in block_ops if pname in op.input_arg_names), None
            )
            if first_reader is not None:
                self._param_device_map[pname] = first_reader.attr(device_attr)

    def _get_device_for_param(self, param_name):
        device = None
        if param_name in self._param_device_map:
            device = self._param_device_map[param_name]
        return device

804 805 806
    def _create_optimization_pass(
        self, parameters_and_grads, param_group_idx=0
    ):
        """Add optimization operators to update gradients to tensors.

        Args:
          parameters_and_grads(list(tuple(Tensor, Tensor))|dict):
            a list of (tensor, gradient) pairs to update. In dynamic mode it
            may instead be a parameter-group dict whose ``'params'`` entry
            holds the pairs and whose other entries are per-group options
            forwarded to ``_append_optimize_op``.
          param_group_idx(int, optional): index of the parameter group being
            processed; used to index the multi-tensor parameter buckets.
            Default 0.

        Returns:
          return_op_list: a list of operators that will complete one step of
            optimization. This will include parameter update ops, global step
            update ops and any other custom ops required by subclasses to manage
            their internal state.
        """
        # This is a default implementation of create_optimization_pass that
        # can be shared by most optimizers. This implementation assumes that
        # the subclass will implement the _append_optimize_op method and the
        # _initialize_tensors method. The subclass can extend the
        # _create_accumulators method if it needs to create accumulators
        # for parameters and extend _finish_update method to add custom ops.

        # Always called under program_guard: use the global block as the loss
        # block. But if the current block is in control flow, append the
        # optimize ops to the grad block of the current block.
        global_block = framework.default_main_program().global_block()
        target_block = global_block
        current_block = framework.default_main_program().current_block()
        if current_block.idx != global_block.idx:
            assert (
                current_block.backward_block_idx != -1
            ), "current block is not global_block, but it doesn't have backward block."
            target_block = framework.default_main_program().blocks[
                current_block.backward_block_idx
            ]

        start = len(target_block.ops)
        self.helper = LayerHelper(self.__class__.__name__)

        self._create_global_learning_rate()

        # NOTE: Multi Tensor support [ Momentum, Adam ] for dygraph mode
        if self._use_multi_tensor and self.__class__.__name__ in [
            'Momentum',
            'Adam',
        ]:
            if (
                len(self._param_dict['FP32_LODTensor'][param_group_idx]) == 0
                and len(self._param_dict['FP16_LODTensor'][param_group_idx])
                == 0
            ):
                # Both buckets are empty for this group, so (presumably on the
                # first call) initialize the multi-tensor state now.
                if isinstance(parameters_and_grads, list):
                    assert param_group_idx == 0
                    self._multi_tensor_init(
                        target_block,
                        [
                            p[0]
                            for p in parameters_and_grads
                            if not p[0].stop_gradient
                        ],
                        param_group_idx,
                    )
                else:
                    self._update_param_group(parameters_and_grads)
                    self._multi_tensor_init(
                        target_block,
                        [
                            p[0]
                            for p in parameters_and_grads['params']
                            if not p[0].stop_gradient
                        ],
                        param_group_idx,
                    )
            if framework._non_static_mode():
                self._append_optimize_multi_tensor_op(
                    target_block,
                    parameters_and_grads,
                    param_group_idx=param_group_idx,
                )
            else:
                self._update_param_device_map(
                    parameters_and_grads, target_block
                )
                # NOTE: Multi Tensor requires all parameters to be in the same device and program.
                # param_grad_list = [p_0,g_0,p_1,g_1,....]
                param_grad_list = []
                for param_and_grad in parameters_and_grads:
                    if (
                        not param_and_grad[0].stop_gradient
                        and param_and_grad[1] is not None
                    ):
                        param_grad_list.append(param_and_grad[0])
                        param_grad_list.append(param_and_grad[1])
                # NOTE(review): raises IndexError when no trainable parameter
                # has a gradient; confirm callers never reach this case.
                with param_grad_list[0].block.program._optimized_guard(
                    param_grad_list
                ), name_scope("optimizer"):
                    device = self._get_device_for_param(param_grad_list[0].name)
                    with device_guard(device):
                        self._append_optimize_multi_tensor_op(
                            target_block,
                            parameters_and_grads,
                            param_group_idx=param_group_idx,
                        )
        else:
            if not framework._non_static_mode():
                params_grads_device_map = (
                    parameters_and_grads['params']
                    if isinstance(parameters_and_grads, dict)
                    else parameters_and_grads
                )
                self._update_param_device_map(
                    params_grads_device_map, target_block
                )

            if isinstance(parameters_and_grads, list):
                self._create_accumulators(
                    target_block,
                    [
                        p[0]
                        for p in parameters_and_grads
                        if not p[0].stop_gradient
                    ],
                )
            else:
                params_acc_dict = parameters_and_grads.copy()
                params_acc_dict['params'] = [
                    p[0]
                    for p in params_acc_dict['params']
                    if not p[0].stop_gradient
                ]
                self._create_accumulators(target_block, params_acc_dict)

            if framework._non_static_mode():
                if isinstance(parameters_and_grads, list):
                    for param_and_grad in parameters_and_grads:
                        if param_and_grad[1] is None:
                            continue
                        if param_and_grad[0].stop_gradient is False:
                            self._append_optimize_op(
                                target_block, param_and_grad
                            )
                else:
                    for param_and_grad in parameters_and_grads['params']:
                        if param_and_grad[1] is None:
                            continue
                        if param_and_grad[0].stop_gradient is False:
                            # Pass the single pair together with the group's
                            # other options (everything except 'params').
                            param_grad_dict = dict()
                            param_grad_dict['params'] = param_and_grad
                            param_grad_dict.update(
                                {
                                    k: v
                                    for k, v in parameters_and_grads.items()
                                    if k != 'params'
                                }
                            )
                            self._append_optimize_op(
                                target_block, param_grad_dict
                            )
            else:
                for param_and_grad in parameters_and_grads:
                    if param_and_grad[1] is None:
                        continue
                    with param_and_grad[0].block.program._optimized_guard(
                        param_and_grad
                    ), name_scope("optimizer"):
                        if param_and_grad[0].stop_gradient is False:
                            device = self._get_device_for_param(
                                param_and_grad[0].name
                            )
                            with device_guard(device):
                                self._append_optimize_op(
                                    target_block, param_and_grad
                                )

        # Get custom finish ops for subclasses
        # FIXME: Need to fix this once we figure out how to handle dependencies
        self._finish_update(target_block, parameters_and_grads)

        end = len(target_block.ops)
        return target_block._slice_ops(start, end)

    def _append_dgc_ops(self, param_and_grad):
        pass

989 990 991 992 993 994 995 996
    def backward(
        self,
        loss,
        startup_program=None,
        parameters=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        The first part of ``minimize``, do auto-diff to append backward operations for
        the current program.

        In dynamic mode no ops are appended: the gradients already held by the
        parameters are collected instead.

        Args:
            loss (Tensor): ``loss`` tensor to run optimizations.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameters``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None. Only used in static mode.
            callbacks (list, optional): list of callable objects to run when appending backward
                operator for one parameter. The default value is None. Only used in static mode.

        Returns:
            list: list of (param, grad) tensor pairs, param is ``Parameter``,
                grad is the gradient value corresponding to the parameter.

        Examples:
            .. code-block:: python

                import paddle
                x = paddle.arange(26, dtype="float32").reshape([2, 13])

                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(x)
                out.backward()
                adam.step()
                adam.clear_grad()
        """
        is_dynamic = framework._non_static_mode()
        act_no_grad_set = (
            None if is_dynamic else self._get_no_grad_set(loss, no_grad_set)
        )

        # Infer dtype by loss if None
        if self._dtype is None:
            self._dtype = loss.dtype

        parameter_list = parameters if parameters else self._parameter_list

        if is_dynamic:
            if framework.in_dygraph_mode():
                # Fetch every gradient in one C++ call: looping over the
                # parameters on the Python side is very slow in eager mode.
                all_grads = core.eager.get_all_grads(parameter_list)
                return [
                    (param, grad)
                    for param, grad in zip(parameter_list, all_grads)
                    if grad is not None
                ]
            # Legacy dygraph path; drop it once legacy mode is removed.
            legacy_pairs = []
            for param in parameter_list:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                if grad_var is not None:
                    legacy_pairs.append((param, grad_var))
            return legacy_pairs

        if callbacks is None:
            callbacks = [error_clip_callback]
        else:
            assert isinstance(callbacks, list)
        program = loss.block.program
        assert len(loss.shape) == 1 and loss.shape[0] == 1, (
            "The loss.shape should be (1L,), but the current loss.shape is {}. "
            "Maybe that you should call paddle.mean to process the current loss.".format(
                loss.shape
            )
        )
        with program_guard(program, startup_program):
            from paddle.incubate.autograd.utils import prim_enabled

            if prim_enabled():
                params_grads = append_backward_new(
                    [loss], parameter_list, act_no_grad_set, callbacks
                )
            else:
                params_grads = append_backward(
                    loss, parameter_list, act_no_grad_set, callbacks
                )
            # Note: since we can't use all_reduce_op now,
            #  dgc_op should be the last op of one grad.
            self._append_dgc_ops(params_grads)
        return params_grads

    def apply_gradients(self, params_grads):
        """
        Second part of `minimize`, appending optimization operators for
        given `params_grads` pairs.

        The pairs are sorted by parameter name, clipped (with the optimizer's
        ``grad_clip`` if set, otherwise via ``append_gradient_clip_ops``),
        regularized, and then handed to ``_create_optimization_pass``.

        Args:
            params_grads (list): list of (param, grad) pair to do optimization.

        Returns:
            list: A list of operators appended to the current program.

        Examples:
            .. code-block:: python

                import paddle

                inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
                linear = paddle.nn.Linear(10, 10)
                out = linear(inp)
                loss = paddle.mean(out)
                optimizer = paddle.optimizer.Adam(learning_rate=0.1,
                        parameters=linear.parameters())
                params_grads = optimizer.backward(loss)
                optimizer.apply_gradients(params_grads)

        """
        by_name = sorted(params_grads, key=lambda pair: pair[0].name)

        # 'optimizer(grad_clip)' or 'set_gradient_clip'
        if self._grad_clip is None:
            clipped = append_gradient_clip_ops(by_name)
        else:
            clipped = self._grad_clip(by_name)

        # Add regularization if any
        regularized = self.append_regularization_ops(
            clipped, self.regularization
        )
        return self._create_optimization_pass(regularized)

1138 1139 1140
    def _apply_optimize(
        self, loss, startup_program, params_grads, param_group_idx=0
    ):
        """
        Second part of `minimize`, appending optimization operators for
        given `params_grads` pairs.

        Args:
            loss (Tensor): loss tensor to run optimizations; only used in
                static mode to locate the program.
            startup_program (Program): startup_program for initializing parameters
                in `parameters`.
            params_grads (list|dict): list of (param, grad) pair to do
                optimization, or (dynamic mode) a parameter-group dict with
                ``'params'`` and ``'grad_clip'`` entries. The dict is updated
                in place.
            param_group_idx (int, optional): parameter group index; must be 0
                in static mode. Default 0.

        Returns:
            list: A list of operators appended to the current program.
        """
        if not framework._non_static_mode():
            assert param_group_idx == 0
            with program_guard(loss.block.program, startup_program):
                return self.apply_gradients(params_grads)

        with program_guard(
            framework.default_main_program(),
            framework.default_startup_program(),
        ):
            if isinstance(params_grads, list):
                if self._grad_clip is not None:
                    params_grads = self._grad_clip(params_grads)
                params_grads = self.append_regularization_ops(
                    params_grads, self.regularization
                )
            else:
                # A parameter group carries its own gradient clip.
                group_clip = params_grads['grad_clip']
                if group_clip is not None:
                    params_grads['params'] = group_clip(params_grads['params'])
                params_grads['params'] = self.append_regularization_ops(
                    params_grads['params'], self.regularization
                )
            return self._create_optimization_pass(
                params_grads, param_group_idx=param_group_idx
            )

1183
    def _create_regularization_of_grad(self, param, grad, regularization=None):
        """Return ``grad`` plus the regularization term for ``param``.

        Function helper of append_regularization_ops. A ``regularizer``
        attribute set on the parameter takes priority over the global
        ``regularization``. When ``grad`` is None, or neither regularizer is
        set, ``grad`` is returned unchanged.

        In dynamic mode the sum is computed directly. In static mode a ``sum``
        op is appended to ``grad.block``, writing to a new LOD_TENSOR variable
        when ``grad`` is SELECTED_ROWS and in place otherwise.
        """
        own_regularizer = getattr(param, 'regularizer', None)
        # If no gradient or no regularization is specified, then we don't need to do anything
        if grad is None or (
            own_regularizer is None and regularization is None
        ):
            return grad

        chosen = (
            own_regularizer if own_regularizer is not None else regularization
        )
        # Add variable for regularization term in grad block
        regularization_term = chosen(param, grad, grad.block)
        assert regularization_term is not None

        if framework.in_dygraph_mode():
            return _C_ops.add_n([grad, regularization_term])
        if framework._in_legacy_dygraph():
            return _legacy_C_ops.sum([grad, regularization_term])

        new_grad = grad
        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
            # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
            # the grad's type and name will be changed. But the gradient's name
            # is used in ParallelExecutor Reduce mode, so I add a flag for
            # the new_grad here.
            new_grad = grad.block.create_var(
                name=grad.name + core.kNewGradSuffix(),
                dtype=param.dtype,
                shape=param.shape,
                lod_level=param.lod_level,
                type=core.VarDesc.VarType.LOD_TENSOR,
            )

        grad.block.append_op(
            type='sum',
            inputs={"X": [grad, regularization_term]},
            outputs={"Out": [new_grad]},
        )
        return new_grad

1231 1232 1233
    def append_regularization_ops(
        self, parameters_and_grads, regularization=None
    ):
        r"""Create and add backward regularization Operators

        Adds the gradient of each parameter's regularizer to that parameter's
        gradient and returns the modified gradients. This is equivalent to
        implementing weight decay in the optimizer.

        Args:
            parameters_and_grads: A list of (parameters, gradients) pairs
                                  that need to be regularized.
            regularization: A global regularizer, applied to every parameter
                            that has no regularizer of its own.

        Returns:
            list[(Variable, Variable)]: list of (parameters, gradients) \
            pair with the regularized gradient
        """
        if framework._non_static_mode():
            return [
                (
                    param,
                    self._create_regularization_of_grad(
                        param, grad, regularization
                    ),
                )
                for param, grad in parameters_and_grads
            ]

        regularized_pairs = []
        override_logged = False
        with framework.name_scope('regularization'):
            for param, grad in parameters_and_grads:
                # Log once when a parameter-level regularizer shadows the
                # optimizer-level one.
                if (
                    not override_logged
                    and param.regularizer is not None
                    and regularization is not None
                ):
                    override_logged = True
                    logging.info(
                        "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
                        "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
                        % regularization.__str__()
                    )
                with param.block.program._optimized_guard([param, grad]):
                    regularized_grad = self._create_regularization_of_grad(
                        param, grad, regularization
                    )
                    regularized_pairs.append((param, regularized_grad))
        return regularized_pairs

M
MRXLT 已提交
1283 1284 1285
    def _get_no_grad_set(self, loss, no_grad_set=None):
        """Return the set of variable names that must not receive gradients.

        Combines ``no_grad_set``, normalized by ``_get_no_grad_set_name``,
        with the names of all parameters in the global block of ``loss``'s
        program whose ``stop_gradient`` is True.
        """
        excluded = _get_no_grad_set_name(no_grad_set)
        global_params = loss.block.program.global_block().all_parameters()
        # A non-trainable parameter should never get a gradient.
        excluded.update(
            p.name for p in global_params if p.stop_gradient is True
        )
        return excluded

    @framework.dygraph_only
    def clear_grad(self, set_to_zero=True):
        """
        Clear the gradients of all optimized parameters for model.

        If the gradients are not cleared, new gradients accumulate on top of
        the previous ones.

        There are two methods to clear grad: set them to zero, or delete them.
        Parameters with ``stop_gradient`` set are skipped. If the optimizer
        holds no parameters, nothing is cleared.

        Args:
            set_to_zero (bool, optional): If True, set grads to zero; otherwise
                delete them. Default is True.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                a = paddle.arange(26, dtype="float32").reshape([2, 13])
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                adam.clear_grad()

        """
        if not self._parameter_list:
            # None or empty: there is nothing to clear. Guarding here avoids
            # iterating None / indexing [0] of an empty list below.
            param_list = []
        elif isinstance(self._parameter_list[0], dict):
            # Parameters were supplied as groups; walk each group's params.
            param_list = [
                p
                for param_group in self._param_groups
                for p in param_group['params']
                if not p.stop_gradient
            ]
        else:
            param_list = [
                p for p in self._parameter_list if not p.stop_gradient
            ]

        if _in_eager_without_dygraph_check():
            for p in param_list:
                p.clear_gradient(set_to_zero)
        else:
            core.clear_gradients(param_list, set_to_zero)
M
MRXLT 已提交
1343

1344
    @imperative_base.no_grad
    def minimize(
        self, loss, startup_program=None, parameters=None, no_grad_set=None
    ):
        """
        Add operations to minimize ``loss`` by updating ``parameters``.

        Equivalent to calling ``backward`` and then applying the resulting
        (param, grad) pairs.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameters``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:
            .. code-block:: python

                import paddle
                linear = paddle.nn.Linear(10, 10)
                input = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
                out = linear(input)
                loss = paddle.mean(out)

                adam = paddle.optimizer.Adam(learning_rate=0.1,
                        parameters=linear.parameters(),
                        weight_decay=0.01)
                loss.backward()
                adam.minimize(loss)
                adam.clear_grad()

        """
        assert isinstance(loss, Variable), "The loss should be an Tensor."

        params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameters=parameters if parameters else self._parameter_list,
            no_grad_set=no_grad_set,
        )
        return (
            self._apply_optimize(
                loss,
                startup_program=startup_program,
                params_grads=params_grads,
            ),
            params_grads,
        )

L
Leo Chen 已提交
1407
    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Only parameters with ``stop_gradient == False`` that currently hold a
        gradient are passed on for updating. When the optimizer was built from
        parameter groups (a list of dicts), each group is optimized separately
        and its group-level options (every key except ``'params'``) are
        forwarded to ``_apply_optimize`` together with the group's
        (param, grad) pairs.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                a = paddle.arange(26, dtype="float32").reshape([2, 13])
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                adam.clear_grad()
        """

        def _collect_params_grads(params):
            # Pair each trainable parameter with its current gradient,
            # skipping frozen parameters and those that received no gradient.
            pairs = []
            for param in params:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                if grad_var is not None:
                    pairs.append((param, grad_var))
            return pairs

        if not isinstance(self._param_groups[0], dict):
            # A flat list of parameters behaves as a single group with index 0.
            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=_collect_params_grads(self._param_groups),
                param_group_idx=0,
            )
            return

        # Optimize parameters group by group, each with its own options.
        for idx, param_group in enumerate(self._param_groups):
            params_grads = defaultdict(list)
            params_grads['params'] = _collect_params_grads(
                param_group['params']
            )
            params_grads.update(
                {k: v for k, v in param_group.items() if k != 'params'}
            )
            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=params_grads,
                param_group_idx=idx,
            )
1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481

    def _add_param_group(self, param_group):
        """
        Add a param group to parameter_list.

        The group's ``'params'`` entry is normalized to a list, and missing
        options are filled in from ``self._default_dict``. Each parameter in
        the group then gets its ``regularizer`` and
        ``optimize_attr['learning_rate']`` set from the group's options.
        Finally the group is appended to ``self._param_groups``.

        Args:
            param_group (dict): The group of Tensors to be optimized with
                different optimization options. Its ``'params'`` entry may be
                a single ``Parameter`` or an ordered iterable of parameters.

        Raises:
            TypeError: If ``param_group['params']`` is a ``set`` (an ordered
                collection is required).
            ValueError: If any parameter already belongs to an existing
                parameter group.
        """
        params = param_group['params']
        if isinstance(params, Parameter):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters should be in ordered collections, "
                "but received set, please use list instead."
            )
        else:
            param_group['params'] = list(params)

        # Update optimization options for each groups: fill in any option
        # the group did not set explicitly.
        for k, v in self._default_dict.items():
            param_group.setdefault(k, v)

        # A parameter may belong to at most one group.
        existing_params = set()
        for group in self._param_groups:
            existing_params.update(group['params'])
        if not existing_params.isdisjoint(param_group['params']):
            raise ValueError(
                "some parameters appear in more than one parameter group"
            )

        group_params = param_group['params']
        if group_params:
            # The group options are identical for every parameter in the
            # group, so resolve them once instead of once per parameter.
            weight_decay = param_group['weight_decay']
            if isinstance(weight_decay, float):
                from ..fluid.regularizer import L2Decay

                # A float weight decay is turned into an L2Decay regularizer.
                regularization = L2Decay(weight_decay)
            else:
                regularization = weight_decay
            learning_rate = param_group.get('learning_rate', 1.0)

            for param in group_params:
                param.regularizer = regularization
                param.optimize_attr['learning_rate'] = learning_rate

        self._param_groups.append(param_group)

    def _update_param_group(self, parameters):
        """
        Update the param group with new entry
        Args:
            parameters (dict): The extra group of Tensors to be optimzed with
            different optimization options. Only used in child class.
        """
        pass
1523 1524

    @framework.dygraph_only
    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        Prepare grouped inputs for the multi-tensor optimizer path.

        All parameters used in the optimizer's calculations (such as
        parameters, master_weight, and velocity_acc for momentum) are grouped
        into python lists by data type (float16, float32). The base
        implementation does nothing; it is overridden in the corresponding
        optimizer file.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
            param_group_idx: index of the parameter group being handled
                (assumed to match the ``param_group_idx`` that ``step``
                passes along — confirm in the overriding optimizer)
        """
        return None

    @framework.dygraph_only
    def _append_optimize_multi_tensor_op(
        self, target_block, parameters_and_grads, param_group_idx
    ):
        """
        For Multi Tensor, append the merged optimize operator to the block.

        The base implementation does nothing and returns ``None``; optimizers
        that support the multi-tensor path are expected to override it.

        Args:
            target_block: the block the merged optimize operator is appended to
            parameters_and_grads: the (param, grad) data to be optimized
            param_group_idx: index of the parameter group being handled
        """
        return None
1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557

    def _is_dtype_fp16_or_bf16(self, dtype):
        """
        Report whether ``dtype`` is one of the half-precision float types.

        :param dtype: instance of core.VarDesc.VarType
        :return: True if dtype is FP16 or BF16, False otherwise
        """
        assert isinstance(
            dtype, core.VarDesc.VarType
        ), "The dtype should be an instance of core.VarDesc.VarType."
        half_precision_types = (
            core.VarDesc.VarType.FP16,
            core.VarDesc.VarType.BF16,
        )
        return dtype in half_precision_types