adam.py 33.4 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import warnings
16
from collections import defaultdict
M
MRXLT 已提交
17

18
import paddle
19
from paddle import _C_ops, _legacy_C_ops
20

21
from ..fluid import core, framework, unique_name
22 23 24 25 26
from ..fluid.dygraph import base as imperative_base
from ..fluid.framework import Variable, in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

27 28
__all__ = []

# Integer dtype codes compared against core.eager.get_grads_types() in the
# merged (multi-tensor) update: index 0 selects the FP32 bucket, indices 1/2
# (float16/bfloat16) select the FP16 bucket.
GRAD_TYPES = [
    int(dtype) for dtype in (paddle.float32, paddle.float16, paddle.bfloat16)
]
30

M
MRXLT 已提交
31 32

class Adam(Optimizer):
33
    r"""
    The Adam optimizer uses the optimization described at the end
    of section 2 of the `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    It dynamically adjusts the learning rate of each parameter using
    the 1st moment estimates and the 2nd moment estimates of the gradient.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate * \
                          \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number in [0, 1) or a Tensor with shape [1] and data type as float32.
            The default value is 0.9.
        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number in [0, 1) or a Tensor with shape [1] and data type as float32.
            The default value is 0.999.
        epsilon (float|Tensor, optional): A small float value for numerical stability.
            It should be a non-negative float number or a Tensor with shape [1] and data type as float32.
            The default value is 1e-08.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate and weight decay; in that case
            the parameters are a list of dicts. Note that the learning_rate in parameter groups
            represents a scale of the base learning_rate.
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            accumulators is updated in both dense mode and sparse mode. If the size of the parameter
            is very large, the update may be very slow. The lazy mode only updates the elements that
            have gradients in the current mini-batch, so it is much faster. But this mode has
            different semantics from the original Adam algorithm and may lead to different results.
            The default value is False.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once. Default is false.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Examples:
        .. code-block:: python

            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(learning_rate=0.1,
                    parameters=linear.parameters())
            loss.backward()
            adam.step()
            adam.clear_grad()

        .. code-block:: python

            # Adam with beta1/beta2 as Tensor and weight_decay as float
            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.Adam(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            loss.backward()
            adam.step()
            adam.clear_grad()

            # Note that the effective learning_rate of linear_2 is 0.01: the group's
            # learning_rate (0.1) scales the base learning_rate (0.1).
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            loss.backward()
            adam.step()
            adam.clear_grad()

    """
    # Accumulator names passed to _add_accumulator / _get_accumulator.
    (
        _moment1_acc_str,
        _moment2_acc_str,
        _beta1_pow_acc_str,
        _beta2_pow_acc_str,
    ) = ("moment1", "moment2", "beta1_pow_acc", "beta2_pow_acc")

164 165 166 167 168 169 170 171 172 173 174 175 176 177
    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        lazy_mode=False,
        multi_precision=False,
        use_multi_tensor=False,
        name=None,
    ):
        """Build an Adam optimizer; see the class docstring for the arguments.

        Raises:
            ValueError: if a non-Tensor beta1/beta2 lies outside [0, 1) or a
                non-Tensor epsilon is not >= 0.
        """
        for required in (learning_rate, beta1, beta2, epsilon):
            assert required is not None

        # Range checks only apply to plain numbers; Tensor-valued
        # hyper-parameters are not validated here.
        if not isinstance(beta1, Variable) and not 0 <= beta1 < 1:
            raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
        if not isinstance(beta2, Variable) and not 0 <= beta2 < 1:
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not isinstance(epsilon, Variable) and not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")

        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "adam"
        self._beta1, self._beta2, self._epsilon = beta1, beta2, epsilon
        self._lazy_mode = lazy_mode
        self._multi_precision = multi_precision
        # FP32 master copies of FP16/BF16 parameters, keyed by parameter name.
        self._master_weights = {}
        # Defaults read back by _update_param_group for parameter groups.
        self._default_dict = dict(
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            lazy_mode=lazy_mode,
        )

        self._use_multi_tensor = use_multi_tensor
        if use_multi_tensor:
            # Per-dtype buckets consumed by the merged (multi-tensor) update.
            self._param_dict = self._create_multi_tensor_dict()
            self._moment1_dict = self._create_multi_tensor_dict()
            self._moment2_dict = self._create_multi_tensor_dict()
            self._beta1_pow_acc_dict = self._create_multi_tensor_dict()
            self._beta2_pow_acc_dict = self._create_multi_tensor_dict()
            self._master_weight_dict = self._create_multi_tensor_dict()
            # FP32 parameters never have master weights.
            self._master_weight_dict['FP32_LODTensor'] = None
Z
zhangbo9674 已提交
227

228
    def _create_master_weight(self, param):
        """Return the FP32 master copy of ``param``, creating it on first use.

        A new master copy is a persistable global FP32 variable. It is
        initialized by a ``cast`` op appended to the startup program's global
        block, and cached in ``self._master_weights`` under ``param.name``.
        """
        if param.name in self._master_weights:
            return self._master_weights[param.name]

        assert isinstance(self.helper, LayerHelper)
        master = paddle.static.create_global_var(
            name=unique_name.generate(param.name + "_fp32_master"),
            shape=param.shape,
            value=0,
            dtype='float32',
            persistable=True,
        )
        startup_block = self.helper.startup_program.global_block()
        startup_block.append_op(
            type="cast",
            inputs={"X": [param]},
            outputs={"Out": [master]},
            attrs={
                "in_dtype": param.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
            },
        )
        self._master_weights[param.name] = master
        return master

    def _get_accumulator(self, name, param):
        """Fetch the accumulator ``name`` registered for ``param``.

        With ``multi_precision`` on and an FP16/BF16 ``param``, the lookup uses
        the parameter's FP32 master weight, which is the variable the
        accumulators are created for in ``_create_accumulators``.

        Args:
            name: name of the accumulator (prefixed with ``self._name`` when set)
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        Raises:
            Exception: if no such accumulator has been created.
        """
        if self._name is not None:
            name = self._name + "_" + name

        use_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        owner = self._master_weights[param.name] if use_master else param
        owner_name = owner.name

        # Test membership first so the accumulator table is never indexed
        # with a missing name.
        if name in self._accumulators and owner_name in self._accumulators[name]:
            return self._accumulators[name][owner_name]
        raise Exception(
            "Accumulator {} does not exist for parameter {}".format(
                name, owner_name
            )
        )

    def _add_moments_pows(self, p):
        """Create the moment and beta-power accumulators for parameter ``p``.

        ``moment1`` / ``moment2`` use ``p``'s dtype, promoted to FP32 when ``p``
        is FP16/BF16. ``beta1_pow_acc`` / ``beta2_pow_acc`` are shape ``[1]``
        LoD tensors placed on CPU and seeded with beta1 / beta2 (i.e. beta^1,
        the power used by the bias-correction term at the first step).

        Args:
            p: the parameter (or its FP32 master weight) to attach accumulators to.
        """
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32

        def initial_pow(beta, static_placeholder):
            # A plain number seeds the accumulator directly. A Tensor beta can
            # be read in dygraph mode (as _append_optimize_op already does), so
            # seed with its real value there instead of a fixed guess. In
            # static mode the Tensor is not evaluated here, so keep the
            # historical placeholder.
            if not isinstance(beta, Variable):
                return beta
            if framework._non_static_mode():
                return beta.numpy().item(0)
            return static_placeholder

        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=initial_pow(self._beta1, 0.9),
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=initial_pow(self._beta2, 0.999),
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
M
MRXLT 已提交
312 313 314

    def _create_accumulators(self, block, parameters):
        """Create the Adam accumulators for every parameter in ``parameters``.

        With ``multi_precision`` on, an FP16/BF16 parameter gets an FP32 master
        weight and the accumulators are attached to that master weight.
        Without it, a warning is emitted and the accumulators are attached to
        the parameter itself.

        Args:
            block (framework.Block): the block being optimized (type-checked only).
            parameters (list|dict): the parameters, or a parameter-group dict
                resolved through ``_update_param_group``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            low_precision = self._is_dtype_fp16_or_bf16(p.dtype)
            if low_precision and self._multi_precision:
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                continue
            if low_precision:
                # The two literals are concatenated, so the separating space
                # must be explicit.
                warnings.warn(
                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adam optimizer."
                )
            self._add_moments_pows(p)
M
MRXLT 已提交
333 334 335

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one Adam update for a single ``(param, grad)`` pair.

        In eager dygraph mode the update runs immediately via
        ``_C_ops.adam_``; in legacy dygraph mode via ``_legacy_C_ops.adam``.
        Both paths return ``None``. Otherwise an ``adam`` op is appended to
        ``block`` and returned.

        Args:
            block (framework.Block): block that receives the static-graph op.
            param_and_grad (tuple|dict): a ``(param, grad)`` pair, or a
                parameter-group dict resolved through ``_update_param_group``.

        Returns:
            The appended operator in static mode, otherwise ``None``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        param = param_and_grad[0]
        grad = param_and_grad[1]
        moment1, moment2, beta1_pow_acc, beta2_pow_acc = [
            self._get_accumulator(acc_name, param)
            for acc_name in (
                self._moment1_acc_str,
                self._moment2_acc_str,
                self._beta1_pow_acc_str,
                self._beta2_pow_acc_str,
            )
        ]
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        master_weight = (
            self._master_weights[param.name] if find_master else None
        )
        lr = self._create_param_lr(param_and_grad)

        def as_scalar(beta):
            # The dygraph kernels take beta as a Python number, so a
            # Tensor-valued beta is unwrapped to its first element.
            if isinstance(beta, Variable):
                return beta.numpy().item(0)
            return beta

        if framework.in_dygraph_mode():
            found_inf = self._get_auxiliary_var('found_inf')
            beta1_value = as_scalar(self._beta1)
            beta2_value = as_scalar(self._beta2)
            _, _, _, _, _, _ = _C_ops.adam_(
                param,
                grad,
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                found_inf,
                beta1_value,
                beta2_value,
                self._epsilon,
                self._lazy_mode,
                1000,  # min_row_size_to_use_multithread
                find_master,
                False,
            )
            return None

        if framework._in_legacy_dygraph():
            beta1_value = as_scalar(self._beta1)
            beta2_value = as_scalar(self._beta2)
            _, _, _, _, _, _ = _legacy_C_ops.adam(
                # inputs
                param,
                grad,
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                # outputs, written in place
                param,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                # attributes as name/value pairs
                'epsilon',
                self._epsilon,
                'lazy_mode',
                self._lazy_mode,
                'min_row_size_to_use_multithread',
                1000,
                'beta1',
                beta1_value,
                'beta2',
                beta2_value,
                'multi_precision',
                find_master,
            )
            return None

        # Static graph: describe the update as an `adam` op in `block`.
        inputs = {
            "Param": [param],
            "Grad": [grad],
            "LearningRate": [lr],
            "Moment1": [moment1],
            "Moment2": [moment2],
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc],
        }
        outputs = {
            "ParamOut": [param],
            "Moment1Out": [moment1],
            "Moment2Out": [moment2],
            "Beta1PowOut": [beta1_pow_acc],
            "Beta2PowOut": [beta2_pow_acc],
        }
        attrs = {
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000,
            "multi_precision": find_master,
        }
        # A Tensor hyper-parameter becomes an op input; a number becomes an attr.
        for input_key, attr_key, value in (
            ('Beta1Tensor', 'beta1', self._beta1),
            ('Beta2Tensor', 'beta2', self._beta2),
            ('EpsilonTensor', 'epsilon', self._epsilon),
        ):
            if isinstance(value, Variable):
                inputs[input_key] = value
            else:
                attrs[attr_key] = value

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        return block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )
488

W
WangXi 已提交
489
    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Parameters with ``stop_gradient=True`` or without a gradient are
        skipped.

        Returns:
            None

        Raises:
            RuntimeError: if a parameter has a sparse gradient while
                regularization (``weight_decay``) is set. This is only checked
                when the parameters are not organized in groups.

        Examples:
            .. code-block:: python

                import paddle

                a = paddle.rand([2,13], dtype="float32")
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                adam.clear_grad()
        """

        def _is_sparse(grad_var):
            # The sparse-gradient query differs between eager mode and legacy
            # dygraph mode.
            if in_dygraph_mode():
                return (
                    hasattr(grad_var, "is_selected_rows")
                    and grad_var.is_selected_rows()
                )
            return hasattr(grad_var, "_is_sparse") and grad_var._is_sparse()

        if not isinstance(self._parameter_list[0], dict):
            params_grads = []
            for param in self._parameter_list:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                if grad_var is None:
                    continue
                if _is_sparse(grad_var) and self.regularization is not None:
                    raise RuntimeError(
                        "Adam don't support weight_decay with sparse parameters, please set it to None."
                    )
                params_grads.append((param, grad_var))

            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=params_grads,
                param_group_idx=0,
            )
        else:
            # optimize parameters in groups
            for idx, param_group in enumerate(self._param_groups):
                params_grads = defaultdict(list)
                for param in param_group['params']:
                    if param.stop_gradient:
                        continue
                    grad_var = param._grad_ivar()
                    if grad_var is not None:
                        params_grads['params'].append((param, grad_var))
                # Carry the group's own options (lr scale, betas, ...) along.
                params_grads.update(
                    {k: v for k, v in param_group.items() if k != 'params'}
                )
                self._apply_optimize(
                    loss=None,
                    startup_program=None,
                    params_grads=params_grads,
                    param_group_idx=idx,
                )
565

566
    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        Group the tensors used by the merged Adam update by data type.

        Each parameter and its moment1 / moment2 / beta1_pow / beta2_pow
        accumulators are appended to ``self._param_dict``,
        ``self._moment1_dict``, ``self._moment2_dict``,
        ``self._beta1_pow_acc_dict`` and ``self._beta2_pow_acc_dict``. They go
        under the key ``'FP32_LODTensor'`` (float32) or ``'FP16_LODTensor'``
        (float16 / bfloat16), at index ``param_group_idx``.

        For FP16/BF16 parameters, ``self._master_weight_dict`` also receives
        the FP32 master weight when ``multi_precision`` is on. Otherwise that
        bucket is set to ``None``.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
            param_group_idx (int): index of the parameter group being grouped.

        Raises:
            ValueError: if a parameter is not float32, float16 or bfloat16.
        """
        self._create_accumulators(target_block, parameters)
        for param in parameters:
            moment1 = self._get_accumulator(self._moment1_acc_str, param)
            moment2 = self._get_accumulator(self._moment2_acc_str, param)
            beta1_pow_acc = self._get_accumulator(
                self._beta1_pow_acc_str, param
            )
            beta2_pow_acc = self._get_accumulator(
                self._beta2_pow_acc_str, param
            )

            if param.dtype == paddle.float32:
                key = 'FP32_LODTensor'
            elif self._is_dtype_fp16_or_bf16(param.dtype):
                key = 'FP16_LODTensor'
            else:
                raise ValueError(
                    "Now multi_tensor_adam only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR."
                )

            for bucket, tensor in (
                (self._param_dict, param),
                (self._moment1_dict, moment1),
                (self._moment2_dict, moment2),
                (self._beta1_pow_acc_dict, beta1_pow_acc),
                (self._beta2_pow_acc_dict, beta2_pow_acc),
            ):
                bucket[key][param_group_idx].append(tensor)

            if key == 'FP16_LODTensor':
                if self._multi_precision:
                    self._master_weight_dict[key][param_group_idx].append(
                        self._master_weights[param.name]
                    )
                else:
                    # Without multi-precision there are no master weights.
                    self._master_weight_dict[key] = None

628
    def _append_optimize_multi_tensor_op(
629 630 631 632
        self,
        target_block,
        parameters_and_grads,
        param_group_idx,
633
    ):
634
        """
Z
zhangbo9674 已提交
635 636 637 638 639 640 641 642
        For Multi Tensor, append optimize merged_operator to block.
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
        lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}

        if isinstance(parameters_and_grads, list):
643 644 645 646 647 648 649 650 651
            if framework.in_dygraph_mode():
                params = [pair[0] for pair in parameters_and_grads]
                grads_types = core.eager.get_grads_types(params)
                for index, tp in enumerate(grads_types):
                    if tp == GRAD_TYPES[0]:
                        grad_dict['FP32_LODTensor'].append(
                            parameters_and_grads[index][1]
                        )
                        lr = self._create_param_lr(parameters_and_grads[index])
Z
zhangbo9674 已提交
652
                        lr_dict['FP32_LODTensor'].append(lr)
653
                    elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]:
654 655 656 657
                        grad_dict['FP16_LODTensor'].append(
                            parameters_and_grads[index][1]
                        )
                        lr = self._create_param_lr(parameters_and_grads[index])
Z
zhangbo9674 已提交
658
                        lr_dict['FP16_LODTensor'].append(lr)
659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674
            else:
                for param_and_grad in parameters_and_grads:
                    if param_and_grad[1] is None:
                        continue
                    if param_and_grad[0].stop_gradient is False:
                        if (
                            param_and_grad[0].dtype == paddle.float32
                            and param_and_grad[1].type
                            == core.VarDesc.VarType.LOD_TENSOR
                        ):
                            grad_dict['FP32_LODTensor'].append(
                                param_and_grad[1]
                            )
                            lr = self._create_param_lr(param_and_grad)
                            lr_dict['FP32_LODTensor'].append(lr)
                        elif (
675
                            self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
676 677 678 679 680 681 682 683
                            and param_and_grad[1].type
                            == core.VarDesc.VarType.LOD_TENSOR
                        ):
                            grad_dict['FP16_LODTensor'].append(
                                param_and_grad[1]
                            )
                            lr = self._create_param_lr(param_and_grad)
                            lr_dict['FP16_LODTensor'].append(lr)
Z
zhangbo9674 已提交
684 685 686 687 688 689 690
        else:
            for param_and_grad in parameters_and_grads['params']:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    param_grad_dict = dict()
                    param_grad_dict['params'] = param_and_grad
691 692 693 694 695 696 697
                    param_grad_dict.update(
                        {
                            k: v
                            for k, v in parameters_and_grads.items()
                            if k != 'params'
                        }
                    )
Z
zhangbo9674 已提交
698
                    param_and_grad = self._update_param_group(param_grad_dict)
699 700 701 702 703
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
704 705 706
                        grad_dict['FP32_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_LODTensor'].append(lr)
707
                    elif (
708
                        self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
709 710 711
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
712 713 714 715 716 717
                        grad_dict['FP16_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_LODTensor'].append(lr)

        multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
        for key in multi_tensor_list:
718
            if len(self._param_dict[key][param_group_idx]) > 0:
719
                find_master = self._multi_precision and key == 'FP16_LODTensor'
Z
zhangbo9674 已提交
720

721 722 723 724 725 726 727 728 729 730
                _beta1 = (
                    self._beta1
                    if not isinstance(self._beta1, Variable)
                    else self._beta1.numpy().item(0)
                )
                _beta2 = (
                    self._beta2
                    if not isinstance(self._beta2, Variable)
                    else self._beta2.numpy().item(0)
                )
Z
zhangbo9674 已提交
731

J
Jiabin Yang 已提交
732
                if framework._non_static_mode():
733 734 735 736 737 738
                    master_weight = self._master_weight_dict[key]
                    master_weight = (
                        master_weight[param_group_idx]
                        if master_weight is not None
                        else None
                    )
739
                    if in_dygraph_mode():
740

741
                        _, _, _, _, _, _ = _C_ops.merged_adam_(
742
                            self._param_dict[key][param_group_idx],
743 744
                            grad_dict[key],
                            lr_dict[key],
745 746 747 748 749
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
750 751 752 753 754 755
                            _beta1,
                            _beta2,
                            self._epsilon,
                            find_master,
                            False,
                        )
756 757
                    else:
                        _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
758
                            self._param_dict[key][param_group_idx],
759 760
                            grad_dict[key],
                            lr_dict[key],
761 762 763 764 765 766 767 768 769 770 771
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
                            self._param_dict[key][param_group_idx],
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
772 773 774 775 776 777 778 779 780
                            'epsilon',
                            self._epsilon,
                            'beta1',
                            _beta1,
                            'beta2',
                            _beta2,
                            'multi_precision',
                            find_master,
                        )
Z
zhangbo9674 已提交
781 782
                else:
                    inputs = {
783
                        "Param": self._param_dict[key][param_group_idx],
Z
zhangbo9674 已提交
784 785
                        "Grad": grad_dict[key],
                        "LearningRate": lr_dict[key],
786 787 788 789 790 791 792 793
                        "Moment1": self._moment1_dict[key][param_group_idx],
                        "Moment2": self._moment2_dict[key][param_group_idx],
                        "Beta1Pow": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2Pow": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
Z
zhangbo9674 已提交
794 795
                    }
                    outputs = {
796 797 798 799 800 801 802 803 804
                        "ParamOut": self._param_dict[key][param_group_idx],
                        "Moment1Out": self._moment1_dict[key][param_group_idx],
                        "Moment2Out": self._moment2_dict[key][param_group_idx],
                        "Beta1PowOut": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2PowOut": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
Z
zhangbo9674 已提交
805 806 807 808
                    }
                    attrs = {
                        "epsilon": self._epsilon,
                        "beta1": _beta1,
809
                        "beta2": _beta2,
Z
zhangbo9674 已提交
810
                    }
811
                    if find_master:
812 813 814
                        inputs["MasterParam"] = self._master_weight_dict[key][
                            param_group_idx
                        ]
Z
zhangbo9674 已提交
815
                        outputs["MasterParamOut"] = self._master_weight_dict[
816
                            key
817
                        ][param_group_idx]
818
                        attrs["multi_precision"] = find_master
819 820 821 822 823 824 825
                    target_block.append_op(
                        type="merged_adam",
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                        stop_gradient=True,
                    )
Z
zhangbo9674 已提交
826 827
        return None

828 829 830 831
    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
832 833 834
        self._lazy_mode = parameters.get(
            'lazy_mode', self._default_dict['lazy_mode']
        )
835 836
        parameters = parameters.get('params')
        return parameters