# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict

import paddle
from paddle import _C_ops

from ..base import core, framework
from ..base.dygraph import base as imperative_base
from ..base.framework import Variable, in_dygraph_mode
from .optimizer import Optimizer

__all__ = []

# Integer dtype codes matched against the values returned by
# ``core.eager.get_grads_types`` in ``Adam._append_optimize_multi_tensor_op``:
# index 0 (float32) selects the 'FP32_LODTensor' group, while indices 1 and 2
# (float16 / bfloat16) select the low-precision 'FP16_LODTensor' group.
GRAD_TYPES = [int(paddle.float32), int(paddle.float16), int(paddle.bfloat16)]


class Adam(Optimizer):
    r"""
    The Adam optimizer uses an optimization described at the end
    of section 2 of `Adam paper <https://arxiv.org/abs/1412.6980>`_ ,
    it can dynamically adjust the learning rate of each parameter using
    the 1st moment estimates and the 2nd moment estimates of the gradient.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate * \
                          \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.9.
        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.999.
        epsilon (float|Tensor, optional): A small float value for numerical stability.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 1e-08.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. And you can specify different options for
            different parameter groups such as the learning rate, weight decay, etc,
            then the parameters are list of dict. Note that the learning_rate in parameter groups
            represents the scale of base learning_rate.
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_base_regularizer_L1Decay`, :ref:`api_base_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_base_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_base_clip_GradientClipByGlobalNorm` , :ref:`api_base_clip_GradientClipByNorm` ,
            :ref:`api_base_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
            then the update may be very slow. The lazy mode only updates the elements that have
            gradient in current mini-batch, so it will be much faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once. Default is false.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Examples:
        .. code-block:: python
            :name: code-example1

            >>> import paddle

            >>> linear = paddle.nn.Linear(10, 10)
            >>> inp = paddle.rand([10,10], dtype="float32")
            >>> out = linear(inp)
            >>> loss = paddle.mean(out)
            >>> adam = paddle.optimizer.Adam(learning_rate=0.1,
            ...         parameters=linear.parameters())
            >>> loss.backward()
            >>> adam.step()
            >>> adam.clear_grad()

        .. code-block:: python
            :name: code-example2

            >>> # Adam with beta1/beta2 as Tensor and weight_decay as float
            >>> import paddle

            >>> linear = paddle.nn.Linear(10, 10)
            >>> inp = paddle.rand([10,10], dtype="float32")
            >>> out = linear(inp)
            >>> loss = paddle.mean(out)
            >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
            >>> beta2 = paddle.to_tensor([0.99], dtype="float32")
            >>> adam = paddle.optimizer.Adam(learning_rate=0.1,
            ...         parameters=linear.parameters(),
            ...         beta1=beta1,
            ...         beta2=beta2,
            ...         weight_decay=0.01)
            >>> loss.backward()
            >>> adam.step()
            >>> adam.clear_grad()

            >>> # Note that the learning_rate of linear_2 is 0.01.
            >>> linear_1 = paddle.nn.Linear(10, 10)
            >>> linear_2 = paddle.nn.Linear(10, 10)
            >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            >>> out = linear_1(inp)
            >>> out = linear_2(out)
            >>> loss = paddle.mean(out)
            >>> adam = paddle.optimizer.Adam(
            ...     learning_rate=0.1,
            ...     parameters=[{
            ...         'params': linear_1.parameters()
            ...     }, {
            ...         'params': linear_2.parameters(),
            ...         'weight_decay': 0.001,
            ...         'learning_rate': 0.1,
            ...         'beta1': 0.8
            ...     }],
            ...     weight_decay=0.01,
            ...     beta1=0.9)
            >>> loss.backward()
            >>> adam.step()
            >>> adam.clear_grad()

    """
    # Names under which the per-parameter accumulators are registered via
    # ``_add_accumulator`` and later fetched via ``_get_accumulator_master``.
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        lazy_mode=False,
        multi_precision=False,
        use_multi_tensor=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        # Only plain Python numbers are range-checked; Tensor-valued
        # hyper-parameters are passed through unvalidated.
        if not isinstance(beta1, Variable):
            if not 0 <= beta1 < 1:
                raise ValueError(
                    "Invaild value of beta1, expect beta1 in [0,1)."
                )
        if not isinstance(beta2, Variable):
            if not 0 <= beta2 < 1:
                raise ValueError(
                    "Invaild value of beta2, expect beta2 in [0,1)."
                )
        if not isinstance(epsilon, Variable):
            if not 0 <= epsilon:
                raise ValueError(
                    "Invaild value of epsilon, expect epsilon >= 0."
                )
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "adam"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lazy_mode = lazy_mode
        self._multi_precision = multi_precision
        self._master_weights = {}
        # Constructor-level hyper-parameters; ``_update_param_group`` falls
        # back to these when a parameter group does not override them.
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lazy_mode': lazy_mode,
        }

        self._use_multi_tensor = use_multi_tensor
        if self._use_multi_tensor:
            # Per-dtype containers filled by ``_multi_tensor_init`` and
            # consumed by ``_append_optimize_multi_tensor_op``.
            self._param_dict = self._create_multi_tensor_dict()
            self._moment1_dict = self._create_multi_tensor_dict()
            self._moment2_dict = self._create_multi_tensor_dict()
            self._beta1_pow_acc_dict = self._create_multi_tensor_dict()
            self._beta2_pow_acc_dict = self._create_multi_tensor_dict()
            self._master_weight_dict = self._create_multi_tensor_dict()
            # FP32 parameters never need a master weight copy.
            self._master_weight_dict['FP32_LODTensor'] = None

    def _add_moments_pows(self, p):
        """
        Register the four Adam accumulators for parameter ``p``.

        The moment accumulators use ``p``'s dtype, promoted to FP32 when
        ``p`` is float16/bfloat16. The beta power accumulators are ``[1]``
        shaped LOD_TENSORs placed on CPU.

        Args:
            p: the parameter (or its master weight) to attach accumulators to.
        """
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        # NOTE(review): when beta1/beta2 is a Tensor, the power accumulator is
        # seeded with the default 0.9 / 0.999 rather than the tensor's value.
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _create_accumulators(self, block, parameters):
        """
        Create Adam accumulators for every parameter that does not have them yet.

        With ``multi_precision`` enabled, float16/bfloat16 parameters get a
        master weight, and the accumulators are attached to that master weight.
        Without it, a warning is emitted for such parameters.

        Args:
            block (framework.Block): the block being optimized.
            parameters (list|dict): parameters, or a parameter-group dict whose
                hyper-parameters are loaded via ``_update_param_group``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            if p.name in self._already_create_accumulater:
                continue
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                self._already_create_accumulater.add(p.name)
                continue
            if (
                self._is_dtype_fp16_or_bf16(p.dtype)
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adam optimizer."
                )
            self._add_moments_pows(p)
            self._already_create_accumulater.add(p.name)

    def _append_optimize_op(self, block, param_and_grad):
        """
        Apply one Adam update to a single ``(param, grad)`` pair.

        In dygraph mode the in-place ``_C_ops.adam_`` kernel runs immediately
        and ``None`` is returned. Otherwise an ``adam`` op is appended to
        ``block`` and returned.

        Args:
            block (framework.Block): the block receiving the op (static mode).
            param_and_grad (tuple|dict): a ``(param, grad)`` pair, or a
                parameter-group dict handled by ``_update_param_group``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        moment1 = self._get_accumulator_master(
            self._moment1_acc_str, param_and_grad[0]
        )
        moment2 = self._get_accumulator_master(
            self._moment2_acc_str, param_and_grad[0]
        )
        beta1_pow_acc = self._get_accumulator_master(
            self._beta1_pow_acc_str, param_and_grad[0]
        )
        beta2_pow_acc = self._get_accumulator_master(
            self._beta2_pow_acc_str, param_and_grad[0]
        )
        # A master weight exists only for low-precision params under
        # multi_precision (see _create_accumulators).
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param_and_grad[0].dtype
        )
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )
        lr = self._create_param_lr(param_and_grad)
        # create the adam optimize op

        if framework.in_dygraph_mode():
            # The dygraph kernel takes beta1/beta2 as Python scalars.
            _beta1 = (
                self._beta1
                if not isinstance(self._beta1, Variable)
                else self._beta1.item(0)
            )
            _beta2 = (
                self._beta2
                if not isinstance(self._beta2, Variable)
                else self._beta2.item(0)
            )

            _, _, _, _, _, _ = _C_ops.adam_(
                param_and_grad[0],
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                # presumably the SkipUpdate input (static mode passes
                # found_inf there); not supplied here
                None,
                _beta1,
                _beta2,
                self._epsilon,
                self._lazy_mode,
                1000,  # min_row_size_to_use_multithread, as in static attrs
                find_master,
                # NOTE(review): presumably use_global_beta_pow — confirm
                False,
            )

            return None
        else:
            inputs = {
                "Param": [param_and_grad[0]],
                "Grad": [param_and_grad[1]],
                "LearningRate": [lr],
                "Moment1": [moment1],
                "Moment2": [moment2],
                "Beta1Pow": [beta1_pow_acc],
                "Beta2Pow": [beta2_pow_acc],
            }

            # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
            found_inf = self._get_auxiliary_var('found_inf')

            if found_inf:
                inputs['SkipUpdate'] = found_inf

            outputs = {
                "ParamOut": [param_and_grad[0]],
                "Moment1Out": [moment1],
                "Moment2Out": [moment2],
                "Beta1PowOut": [beta1_pow_acc],
                "Beta2PowOut": [beta2_pow_acc],
            }
            attrs = {
                "lazy_mode": self._lazy_mode,
                "min_row_size_to_use_multithread": 1000,
                "multi_precision": find_master,
            }

            # Tensor-valued hyper-parameters are fed as op inputs; plain
            # numbers become op attributes.
            if isinstance(self._beta1, Variable):
                inputs['Beta1Tensor'] = self._beta1
            else:
                attrs['beta1'] = self._beta1
            if isinstance(self._beta2, Variable):
                inputs['Beta2Tensor'] = self._beta2
            else:
                attrs['beta2'] = self._beta2
            if isinstance(self._epsilon, Variable):
                inputs['EpsilonTensor'] = self._epsilon
            else:
                attrs['epsilon'] = self._epsilon

            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight

            adam_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )

            return adam_op

    @imperative_base.no_grad
    @framework.non_static_only
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Returns:
            None

        Examples:
            .. code-block:: python

                >>> import paddle

                >>> a = paddle.rand([2,13], dtype="float32")
                >>> linear = paddle.nn.Linear(13, 5)
                >>> # This can be any optimizer supported by dygraph.
                >>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
                ...                             parameters = linear.parameters())
                >>> out = linear(a)
                >>> out.backward()
                >>> adam.step()
                >>> adam.clear_grad()
        """
        if paddle.base.dygraph.base.in_to_static_mode():
            self._declarative_step()
            return

        if not isinstance(self._parameter_list[0], dict):
            params_grads = []
            for param in self._parameter_list:
                if param.stop_gradient:
                    continue
                if param._grad_ivar() is not None:
                    grad_var = param._grad_ivar()
                    # Sparse gradients cannot be combined with weight decay.
                    if in_dygraph_mode():
                        if (
                            hasattr(grad_var, "is_selected_rows")
                            and grad_var.is_selected_rows()
                            and self.regularization is not None
                        ):
                            raise RuntimeError(
                                "Adam don't support weight_decay with sparse parameters, please set it to None."
                            )
                    else:
                        if (
                            hasattr(grad_var, "_is_sparse")
                            and grad_var._is_sparse()
                            and self.regularization is not None
                        ):
                            raise RuntimeError(
                                "Adam don't support weight_decay with sparse parameters, please set it to None."
                            )
                    params_grads.append((param, grad_var))

            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=params_grads,
                param_group_idx=0,
            )
        else:
            # optimize parameters in groups
            for idx, param_group in enumerate(self._param_groups):
                params_grads = defaultdict(list)
                for param in param_group['params']:
                    if param.stop_gradient:
                        continue
                    if param._grad_ivar() is not None:
                        grad_var = param._grad_ivar()
                        params_grads['params'].append((param, grad_var))
                # Carry the group's hyper-parameter overrides alongside its
                # (param, grad) pairs.
                params_grads.update(
                    {k: v for k, v in param_group.items() if k != 'params'}
                )
                self._apply_optimize(
                    loss=None,
                    startup_program=None,
                    params_grads=params_grads,
                    param_group_idx=idx,
                )

    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        Register parameters and their Adam state into the per-dtype multi-tensor lists.

        Each parameter, with its moment1/moment2/beta-pow accumulators, is
        appended under 'FP32_LODTensor' or 'FP16_LODTensor' (float16 and
        bfloat16) for the given parameter group. Under ``multi_precision`` the
        master weight of a low-precision parameter is recorded as well.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
            param_group_idx: index of the parameter group being registered

        Raises:
            ValueError: if a parameter is neither float32 nor float16/bfloat16.
        """
        self._create_accumulators(target_block, parameters)
        for param in parameters:
            moment1 = self._get_accumulator_master(self._moment1_acc_str, param)
            moment2 = self._get_accumulator_master(self._moment2_acc_str, param)
            beta1_pow_acc = self._get_accumulator_master(
                self._beta1_pow_acc_str, param
            )
            beta2_pow_acc = self._get_accumulator_master(
                self._beta2_pow_acc_str, param
            )

            if param.dtype == paddle.float32:
                self._param_dict['FP32_LODTensor'][param_group_idx].append(
                    param
                )
                self._moment1_dict['FP32_LODTensor'][param_group_idx].append(
                    moment1
                )
                self._moment2_dict['FP32_LODTensor'][param_group_idx].append(
                    moment2
                )
                self._beta1_pow_acc_dict['FP32_LODTensor'][
                    param_group_idx
                ].append(beta1_pow_acc)
                self._beta2_pow_acc_dict['FP32_LODTensor'][
                    param_group_idx
                ].append(beta2_pow_acc)
            elif self._is_dtype_fp16_or_bf16(param.dtype):
                self._param_dict['FP16_LODTensor'][param_group_idx].append(
                    param
                )
                self._moment1_dict['FP16_LODTensor'][param_group_idx].append(
                    moment1
                )
                self._moment2_dict['FP16_LODTensor'][param_group_idx].append(
                    moment2
                )
                self._beta1_pow_acc_dict['FP16_LODTensor'][
                    param_group_idx
                ].append(beta1_pow_acc)
                self._beta2_pow_acc_dict['FP16_LODTensor'][
                    param_group_idx
                ].append(beta2_pow_acc)
                if self._multi_precision:
                    self._master_weight_dict['FP16_LODTensor'][
                        param_group_idx
                    ].append(self._master_weights[param.name])
                else:
                    # No master weights without multi_precision.
                    self._master_weight_dict['FP16_LODTensor'] = None
            else:
                raise ValueError(
                    "Now multi_tensor_adam only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR."
                )

    def _append_optimize_multi_tensor_op(
        self,
        target_block,
        parameters_and_grads,
        param_group_idx,
    ):
        """
        For Multi Tensor, append optimize merged_operator to block.

        Gradients and learning rates are bucketed by dtype ('FP32_LODTensor' /
        'FP16_LODTensor'). One fused ``merged_adam`` update is then issued per
        bucket that has parameters registered for ``param_group_idx``.

        In dygraph mode the update is executed immediately via
        ``_C_ops.merged_adam_``, and it is skipped when the auxiliary
        ``found_inf`` flag is truthy. In static mode a ``merged_adam`` op is
        appended to ``target_block``.

        Args:
            target_block (framework.Block): block receiving the op (static mode).
            parameters_and_grads (list|dict): ``(param, grad)`` pairs, or a
                parameter-group dict with a 'params' entry plus overrides.
            param_group_idx (int): index of the parameter group to update.

        Returns:
            None
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
        lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}

        if isinstance(parameters_and_grads, list):
            if framework.in_dygraph_mode():
                params = [pair[0] for pair in parameters_and_grads]
                # Compare each grad's dtype code against GRAD_TYPES; other
                # codes are not placed in any bucket.
                grads_types = core.eager.get_grads_types(params)
                for index, tp in enumerate(grads_types):
                    if tp == GRAD_TYPES[0]:
                        grad_dict['FP32_LODTensor'].append(
                            parameters_and_grads[index][1]
                        )
                        lr = self._create_param_lr(parameters_and_grads[index])
                        lr_dict['FP32_LODTensor'].append(lr)
                    elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]:
                        grad_dict['FP16_LODTensor'].append(
                            parameters_and_grads[index][1]
                        )
                        lr = self._create_param_lr(parameters_and_grads[index])
                        lr_dict['FP16_LODTensor'].append(lr)
            else:
                for param_and_grad in parameters_and_grads:
                    if param_and_grad[1] is None:
                        continue
                    if param_and_grad[0].stop_gradient is False:
                        if (
                            param_and_grad[0].dtype == paddle.float32
                            and param_and_grad[1].type
                            == core.VarDesc.VarType.LOD_TENSOR
                        ):
                            grad_dict['FP32_LODTensor'].append(
                                param_and_grad[1]
                            )
                            lr = self._create_param_lr(param_and_grad)
                            lr_dict['FP32_LODTensor'].append(lr)
                        elif (
                            self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                            and param_and_grad[1].type
                            == core.VarDesc.VarType.LOD_TENSOR
                        ):
                            grad_dict['FP16_LODTensor'].append(
                                param_and_grad[1]
                            )
                            lr = self._create_param_lr(param_and_grad)
                            lr_dict['FP16_LODTensor'].append(lr)
        else:
            for param_and_grad in parameters_and_grads['params']:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    # Re-attach the group's overrides to this pair so that
                    # _update_param_group loads them before the lr is built.
                    param_grad_dict = {}
                    param_grad_dict['params'] = param_and_grad
                    param_grad_dict.update(
                        {
                            k: v
                            for k, v in parameters_and_grads.items()
                            if k != 'params'
                        }
                    )
                    param_and_grad = self._update_param_group(param_grad_dict)
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
                        grad_dict['FP32_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_LODTensor'].append(lr)
                    elif (
                        self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
                        grad_dict['FP16_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_LODTensor'].append(lr)

        multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
        for key in multi_tensor_list:
            if len(self._param_dict[key][param_group_idx]) > 0:
                find_master = self._multi_precision and key == 'FP16_LODTensor'

                _beta1 = (
                    self._beta1
                    if not isinstance(self._beta1, Variable)
                    else self._beta1.item(0)
                )
                _beta2 = (
                    self._beta2
                    if not isinstance(self._beta2, Variable)
                    else self._beta2.item(0)
                )

                if framework.in_dygraph_mode():
                    master_weight = self._master_weight_dict[key]
                    master_weight = (
                        master_weight[param_group_idx]
                        if master_weight is not None
                        else None
                    )
                    # A truthy found_inf skips this update. A Tensor flag is
                    # replaced by the equivalent Python bool.
                    found_inf = self._get_auxiliary_var('found_inf')
                    if found_inf:
                        if isinstance(found_inf, core.eager.Tensor):
                            self._set_auxiliary_var('found_inf', True)
                    else:
                        if isinstance(found_inf, core.eager.Tensor):
                            self._set_auxiliary_var('found_inf', False)
                        _, _, _, _, _, _ = _C_ops.merged_adam_(
                            self._param_dict[key][param_group_idx],
                            grad_dict[key],
                            lr_dict[key],
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
                            _beta1,
                            _beta2,
                            self._epsilon,
                            find_master,
                            False,
                        )
                else:
                    inputs = {
                        "Param": self._param_dict[key][param_group_idx],
                        "Grad": grad_dict[key],
                        "LearningRate": lr_dict[key],
                        "Moment1": self._moment1_dict[key][param_group_idx],
                        "Moment2": self._moment2_dict[key][param_group_idx],
                        "Beta1Pow": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2Pow": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
                    }
                    outputs = {
                        "ParamOut": self._param_dict[key][param_group_idx],
                        "Moment1Out": self._moment1_dict[key][param_group_idx],
                        "Moment2Out": self._moment2_dict[key][param_group_idx],
                        "Beta1PowOut": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2PowOut": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
                    }
                    attrs = {
                        "epsilon": self._epsilon,
                        "beta1": _beta1,
                        "beta2": _beta2,
                    }
                    if find_master:
                        inputs["MasterParam"] = self._master_weight_dict[key][
                            param_group_idx
                        ]
                        outputs["MasterParamOut"] = self._master_weight_dict[
                            key
                        ][param_group_idx]
                        attrs["multi_precision"] = find_master
                    target_block.append_op(
                        type="merged_adam",
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                        stop_gradient=True,
                    )
        return None

    def _update_param_group(self, parameters):
        """
        Load a parameter group's hyper-parameters into the optimizer.

        ``beta1``, ``beta2``, ``epsilon`` and ``lazy_mode`` are overwritten on
        ``self`` from ``parameters``. Any key the group does not set falls back
        to the constructor values in ``self._default_dict``.

        Args:
            parameters (dict): a parameter-group dict with a 'params' entry.

        Returns:
            The group's 'params' entry (``None`` if absent).
        """
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._lazy_mode = parameters.get(
            'lazy_mode', self._default_dict['lazy_mode']
        )
        parameters = parameters.get('params')
        return parameters