# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from .optimizer import Optimizer
from .lr import LRScheduler
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, Parameter
from ..fluid import unique_name
from ..fluid import layers
from ..fluid.layer_helper import LayerHelper
from ..fluid.clip import GradientClipBase
from ..fluid.dygraph import base as imperative_base
from collections.abc import Callable
from .. import _C_ops
import paddle

# Empty on purpose: ``from <this module> import *`` exports nothing.
# AdamW is presumably re-exported by name elsewhere (the examples below use
# ``paddle.optimizer.AdamW``).
__all__ = []


class AdamW(Optimizer):
    r"""
    The AdamW optimizer is implemented based on the AdamW Optimization
    in paper `DECOUPLED WEIGHT DECAY REGULARIZATION <https://arxiv.org/pdf/1711.05101.pdf>`_.
    It resolves the problem of L2 regularization failure in the Adam optimizer.

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate *
            \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * (\frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon} + \lambda * param)


    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate, weight decay, etc.; in that case
            the parameters are a list of dict. Note that the learning_rate in parameter groups
            represents the scale of the base learning_rate.
            The default value is None in static mode, at this time all parameters will be updated.
        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.9.
        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
        weight_decay (float|Tensor, optional): The weight decay coefficient (:math:`\lambda` above).
            It can be a float or a Tensor; an int is also accepted and converted to float.
            The default value is 0.01.
        lr_ratio (function|None, optional): If it is not None,
            the learning rate will be updated with layerwise learning rate ratio:
            it is called with each parameter and its return value is passed to the
            adamw kernel as that parameter's ``lr_ratio``. Otherwise, the learning
            rate is the original. Only available when Paddle is compiled with CUDA.
            Default: None.
        apply_decay_param_fun (function|None, optional): If it is not None,
            only tensors that makes apply_decay_param_fun(Tensor.name)==True
            will be updated with weight decay. It only works when we want to specify tensors.
            Default: None.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
            then the update may be very slow. The lazy mode only updates the elements that have
            gradient in the current mini-batch, so it will be much faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating.
            When True, FP16 parameters are updated through an FP32 master copy. Default is false.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, AdamW doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            opt = paddle.optimizer.AdamW(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            out.backward()
            opt.step()
            opt.clear_grad()


            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            opt = paddle.optimizer.AdamW(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            out.backward()
            opt.step()
            opt.clear_grad()

    """

    # Accumulator names registered per parameter via _add_accumulator.
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 parameters=None,
                 weight_decay=0.01,
                 lr_ratio=None,
                 apply_decay_param_fun=None,
                 grad_clip=None,
                 lazy_mode=False,
                 multi_precision=False,
                 name=None):
        # NOTE: Optimizer.__init__ is not called; the state the rest of this
        # class relies on (learning-rate map, accumulators, helper, ...) is
        # set up manually below.
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        # Range checks only make sense for Python scalars. Comparing a Tensor
        # produces a Tensor rather than a Python bool, which cannot be relied
        # on as a condition (in particular for static-graph Variables), so
        # Tensor-valued hyper-parameters are not range-checked here.
        if not isinstance(beta1, Variable):
            if not 0 <= beta1 < 1:
                raise ValueError(
                    "Invalid value of beta1, expect beta1 in [0,1).")
        if not isinstance(beta2, Variable):
            if not 0 <= beta2 < 1:
                raise ValueError(
                    "Invalid value of beta2, expect beta2 in [0,1).")
        if not isinstance(epsilon, Variable):
            if not 0 <= epsilon:
                raise ValueError(
                    "Invalid value of epsilon, expect epsilon >= 0.")
        # Accept integral coefficients such as ``weight_decay=0``. bool is a
        # subclass of int but is still rejected, as it is almost surely a bug.
        if isinstance(weight_decay, int) and not isinstance(weight_decay, bool):
            weight_decay = float(weight_decay)
        if not isinstance(weight_decay, (float, framework.Variable)):
            raise TypeError("weight_decay should be float or Tensor.")
        if lr_ratio is not None:
            assert isinstance(lr_ratio, Callable)
            if not core.is_compiled_with_cuda():
                raise NotImplementedError(
                    "'lr_ratio' is unimplemented in CPU, XPU and NPU")

        if parameters is not None:
            # paddle.Tensor is also iterable, so here we don't check whether
            # the input is iterable, if the input is paddle.Tensor, the
            # list(paddle.Tensor) will be a error value
            if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
                raise TypeError(
                    "`parameters` argument given to the optimizer should be "
                    "an iterable of paddle Tensors, but got argument type is `{}`."
                    .format(type(parameters)))
            if isinstance(parameters, dict):
                raise TypeError(
                    "`parameters` argument should not get dict type, "
                    "if parameter groups is needed, please set `parameters`"
                    " as list of dict")
            self._parameter_list = list(parameters)
        else:
            self._parameter_list = None

        self._name = name
        if framework._non_static_mode():
            if self._parameter_list is None:
                raise AttributeError(
                    "parameters argument given to the Optimizer should not be None in dygraph mode."
                )

        if not isinstance(learning_rate, (float, LRScheduler)):
            raise TypeError(
                "learning rate should be float or LRScheduler, got %s here" %
                type(learning_rate))
        if grad_clip is not None:
            if not isinstance(grad_clip, GradientClipBase):
                raise TypeError(
                    "'grad_clip' should be an instance of GradientClipBase's derived class"
                )

        self._dtype = None
        # Infer the dtype from the first parameter
        if self._parameter_list:
            if isinstance(self._parameter_list[0], dict):
                for param_group in self._parameter_list:
                    assert 'params' in param_group, \
                        'params should be set in parameters if parameter groups are optimized in different options'
                self._dtype = self._parameter_list[0]['params'][0].dtype
            else:
                self._dtype = self._parameter_list[0].dtype

        # each program should have a independent learning rate
        # program -> tensor(learning_rate)
        self._learning_rate_map = dict()
        # Dictionary of accumulators. Some optimizer subclasses need to
        # allocate and manage extra tensors associated with the parameters
        # to train. These tensors are called accumulators.
        # {accum_name : { parameter_name : accumulator_for_parameter, ...}, ...}
        self._accumulators = defaultdict(lambda: dict())
        self.helper = None
        self._opti_name_list = []
        self._accumulators_holder = {}
        self._param_device_map = dict()
        self.clear_gradients = self.clear_grad

        self.type = "adamw"
        self._learning_rate = learning_rate
        self._params_name = set()
        self._apply_decay_param_fun = apply_decay_param_fun
        self._weight_decay = weight_decay
        self._grad_clip = grad_clip
        self._lr_ratio = lr_ratio
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lazy_mode = lazy_mode
        self._multi_precision = multi_precision
        # parameter name -> FP32 master copy (only used with multi_precision)
        self._master_weights = {}

        # Per-group options that fall back to these constructor values.
        self._default_dict = {
            'weight_decay': weight_decay,
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lazy_mode': lazy_mode,
            'grad_clip': grad_clip
        }

        self._param_groups = []
        if self._parameter_list and isinstance(self._parameter_list[0], dict):
            for param_group in self._parameter_list:
                self._add_param_group(param_group.copy())
        else:
            self._param_groups = self._parameter_list

        self._use_multi_tensor = None
        # AdamW applies decoupled weight decay inside its own kernel, so no
        # separate regularization is configured here.
        self.regularization = None
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
        """Store an auxiliary value (e.g. ``'found_inf'``) under ``key``."""
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        """Return the auxiliary value stored under ``key``, or None if unset."""
        return self._auxiliary_vars.get(key)

    def _add_param_group(self, param_group):
        """
        Add a param group to parameter_list.

        Args:
            param_group (dict): The group of Tensors to be optimized with
                different optimization options. Its ``'params'`` entry may be
                a single Tensor or an ordered collection of Tensors.

        Raises:
            TypeError: if ``'params'`` is a ``set`` (unordered).
            ValueError: if a parameter already belongs to another group.
        """
        params = param_group['params']
        # A single Tensor must be wrapped: list(tensor) would iterate over its
        # rows instead of treating it as one parameter.
        if isinstance(params, (Parameter, paddle.Tensor, core.eager.Tensor)):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters should be in ordered collections,"
                "but received set, please use list instead.")
        else:
            param_group['params'] = list(params)

        # Update optimization options for each groups
        for k, v in self._default_dict.items():
            param_group.setdefault(k, v)

        param_set = set()
        for group in self._param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError(
                "some parameters appear in more than one parameter group")

        # The group's learning_rate is stored as a per-parameter scale.
        for param in param_group['params']:
            param.optimize_attr['learning_rate'] = param_group.get(
                'learning_rate', 1.)

        self._param_groups.append(param_group)

    def _create_master_weight(self, param):
        """Return the FP32 master copy of ``param``, creating it if needed.

        On first use a persistable float32 global var is created and a
        ``cast`` op copying ``param`` into it is appended to the startup
        program's global block.
        """
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = layers.create_global_var(name=var_name,
                                           shape=param.shape,
                                           value=0,
                                           dtype='float32',
                                           persistable=True)
            block = self.helper.startup_program.global_block()
            block.append_op(type="cast",
                            inputs={"X": [param]},
                            outputs={"Out": [var]},
                            attrs={
                                "in_dtype": param.dtype,
                                "out_dtype": core.VarDesc.VarType.FP32
                            })
            self._master_weights[param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        When multi_precision is on and ``param`` is FP16, the accumulator is
        looked up under the name of its FP32 master weight.

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        Raises:
            Exception: if no such accumulator has been created.
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        target_param = self._master_weights[
            param.name] if find_master else param
        target_name = target_param.name
        if (name not in self._accumulators
                or target_name not in self._accumulators[name]):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name))
        return self._accumulators[name][target_name]

    def _add_moments_pows(self, p):
        """Create the moment and beta-power accumulators for parameter ``p``.

        FP16 parameters get FP32 accumulators. The beta-power accumulators
        are shape-[1] CPU tensors initialised to beta1/beta2; when a beta is a
        Tensor its initial value falls back to 0.9 / 0.999.
        """
        acc_dtype = p.dtype
        if acc_dtype == core.VarDesc.VarType.FP16:
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9 if isinstance(self._beta1, Variable) \
                    else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999 if isinstance(self._beta2, Variable) \
                    else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')

    def _create_accumulators(self, block, parameters):
        """Create accumulator tensors for first and second moments.

        Args:
            block: the framework.Block being optimized.
            parameters: a list of parameters, or a parameter-group dict whose
                options are first loaded via ``_update_param_group``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                # Accumulators live on the FP32 master copy.
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                continue
            if p.dtype == core.VarDesc.VarType.FP16:
                # Reaching here means multi_precision is off.
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the AdamW optimizer."
                )
            self._add_moments_pows(p)

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one AdamW update for a single ``(param, grad)`` pair.

        In dynamic mode the adamw kernel is invoked directly and None is
        returned; in static mode an ``adamw`` op is appended to ``block`` and
        returned.

        Args:
            block: the framework.Block to append the op to.
            param_and_grad: a ``(param, grad)`` pair, or a parameter-group
                dict whose ``'params'`` entry is that pair.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)
        param, grad = param_and_grad

        # Whether we should do weight decay for the parameter.
        with_decay = True
        if self._apply_decay_param_fun is not None \
                and not self._apply_decay_param_fun(param.name):
            with_decay = False

        moment1 = self._get_accumulator(self._moment1_acc_str, param)
        moment2 = self._get_accumulator(self._moment2_acc_str, param)
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param)
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, param)
        find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        master_weight = (self._master_weights[param.name]
                         if find_master else None)
        lr = self._create_param_lr(param_and_grad)
        # Layer-wise learning-rate ratio for this parameter (1. when unset).
        lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio(param)

        # create the adamw optimize op
        if framework._non_static_mode():
            # The dynamic kernels take beta1/beta2 as Python scalars.
            _beta1 = self._beta1 if not isinstance(
                self._beta1, Variable) else self._beta1.numpy().item(0)
            _beta2 = self._beta2 if not isinstance(
                self._beta2, Variable) else self._beta2.numpy().item(0)

            if framework.in_dygraph_mode():
                found_inf = self._get_auxiliary_var('found_inf')
                _, _, _, _, _, _ = _C_ops.final_state_adamw_(
                    param, grad, lr, moment1, moment2, beta1_pow_acc,
                    beta2_pow_acc, master_weight, found_inf, _beta1, _beta2,
                    self._epsilon, lr_ratio_, self._weight_decay, with_decay,
                    self._lazy_mode, 1000, find_master, False)
            else:
                _, _, _, _, _, _ = _C_ops.adamw(
                    param, grad, lr, moment1, moment2, beta1_pow_acc,
                    beta2_pow_acc, master_weight, param, moment1, moment2,
                    beta1_pow_acc, beta2_pow_acc, master_weight, 'epsilon',
                    self._epsilon, 'lazy_mode', self._lazy_mode,
                    'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
                    'beta2', _beta2, "with_decay", with_decay, 'coeff',
                    self._weight_decay, 'multi_precision', find_master,
                    'lr_ratio', lr_ratio_)
            return None

        inputs = {
            "Param": [param],
            "Grad": [grad],
            "LearningRate": [lr],
            "Moment1": [moment1],
            "Moment2": [moment2],
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc],
        }

        # Pass found_inf to adamw, to skip update for not only param, but also momentum and beta_pow
        found_inf = self._get_auxiliary_var('found_inf')

        if found_inf:
            inputs['SkipUpdate'] = found_inf

        outputs = {
            "ParamOut": [param],
            "Moment1Out": [moment1],
            "Moment2Out": [moment2],
            "Beta1PowOut": [beta1_pow_acc],
            "Beta2PowOut": [beta2_pow_acc],
        }
        attrs = {
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000,
            "multi_precision": find_master,
            "with_decay": with_decay,
            "coeff": self._weight_decay,
            "lr_ratio": lr_ratio_,
        }

        # Tensor-valued hyper-parameters are fed as op inputs, scalars as attrs.
        if isinstance(self._beta1, Variable):
            inputs['Beta1Tensor'] = self._beta1
        else:
            attrs['beta1'] = self._beta1
        if isinstance(self._beta2, Variable):
            inputs['Beta2Tensor'] = self._beta2
        else:
            attrs['beta2'] = self._beta2
        if isinstance(self._epsilon, Variable):
            inputs['EpsilonTensor'] = self._epsilon
        else:
            attrs['epsilon'] = self._epsilon

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adamw_op = block.append_op(type=self.type,
                                   inputs=inputs,
                                   outputs=outputs,
                                   attrs=attrs,
                                   stop_gradient=True)

        return adamw_op

    def __str__(self):
        return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

    def _gather_params_grads_for_step(self, params):
        """Collect ``(param, grad)`` pairs for the parameters that have a gradient.

        Parameters with ``stop_gradient`` set, or whose gradient is None, are
        skipped.

        Raises:
            RuntimeError: if a gradient is sparse while ``self.regularization``
                is set (it is None after ``__init__``).
        """
        params_grads = []
        for param in params:
            if param.stop_gradient:
                continue
            grad_var = param._grad_ivar()
            if grad_var is None:
                continue
            if framework.in_dygraph_mode():
                is_sparse = hasattr(grad_var, "is_selected_rows"
                                    ) and grad_var.is_selected_rows()
            else:
                is_sparse = hasattr(grad_var,
                                    "_is_sparse") and grad_var._is_sparse()
            if is_sparse and self.regularization is not None:
                raise RuntimeError(
                    "AdamW don't support weight_decay with sparse parameters, please set it to None."
                )
            params_grads.append((param, grad_var))
        return params_grads

    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                a = paddle.rand([2,13], dtype="float32")
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                opt = paddle.optimizer.AdamW(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(a)
                out.backward()
                opt.step()
                opt.clear_grad()
        """
        if not self._parameter_list or not isinstance(
                self._parameter_list[0], dict):
            params_grads = self._gather_params_grads_for_step(
                self._parameter_list)
            self._apply_optimize(loss=None,
                                 startup_program=None,
                                 params_grads=params_grads)
        else:
            # optimize parameters in groups
            for param_group in self._param_groups:
                params_grads = defaultdict(lambda: list())
                params_grads['params'] = self._gather_params_grads_for_step(
                    param_group['params'])
                # Carry the group's options (beta1, weight_decay, ...) along.
                params_grads.update(
                    {k: v
                     for k, v in param_group.items() if k != 'params'})
                self._apply_optimize(loss=None,
                                     startup_program=None,
                                     params_grads=params_grads)

    def _update_param_group(self, parameters):
        """Load a parameter group's options into the optimizer and return its params.

        beta1, beta2, epsilon, lazy_mode and weight_decay are overwritten on
        ``self`` from ``parameters`` (falling back to the constructor
        defaults), so they persist until the next group is loaded.
        """
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._lazy_mode = parameters.get('lazy_mode',
                                         self._default_dict['lazy_mode'])
        self._weight_decay = parameters.get('weight_decay',
                                            self._default_dict['weight_decay'])
        parameters = parameters.get('params')
        return parameters