# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict

from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, in_dygraph_mode
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from ..fluid.dygraph import base as imperative_base

import paddle
from paddle import _C_ops, _legacy_C_ops
__all__ = []


class Adam(Optimizer):
    r"""
    The Adam optimizer uses an optimization described at the end
    of section 2 of `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    It dynamically adjusts the learning rate of each parameter using
    the 1st moment estimates and the 2nd moment estimates of the gradient.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate * \
                          \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.9.
        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.999.
        epsilon (float|Tensor, optional): A small float value for numerical stability.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 1e-08.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate and weight decay;
            in that case ``parameters`` is a list of dicts. Note that the learning_rate in parameter groups
            represents the scale of the base learning_rate.
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
            then the update may be very slow. The lazy mode only updates the elements that have
            gradients in the current mini-batch, so it will be much faster. But this mode has
            different semantics from the original Adam algorithm and may lead to different results.
            The default value is False.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once. Default is false.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Examples:
        .. code-block:: python

            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(learning_rate=0.1,
                    parameters=linear.parameters())
            loss.backward()
            adam.step()
            adam.clear_grad()

        .. code-block:: python

            # Adam with beta1/beta2 as Tensor and weight_decay as float
            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.Adam(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            loss.backward()
            adam.step()
            adam.clear_grad()

            # Note that the learning_rate of linear_2 is 0.01 (0.1 * 0.1, since a
            # group's learning_rate scales the base learning_rate).
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            loss.backward()
            adam.step()
            adam.clear_grad()
    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        lazy_mode=False,
        multi_precision=False,
        use_multi_tensor=False,
        name=None,
    ):
        """Validate the Adam hyper-parameters and set up optimizer state.

        See the class docstring for the meaning of every argument.

        Raises:
            AssertionError: if learning_rate, beta1, beta2 or epsilon is None.
            ValueError: if a non-Tensor beta1/beta2 lies outside [0, 1) or a
                non-Tensor epsilon is not >= 0.
        """
        # None is never an acceptable hyper-parameter value.
        for required in (learning_rate, beta1, beta2, epsilon):
            assert required is not None

        # Only plain numbers are range-checked; Tensor-valued hyper-parameters
        # are accepted here without inspection.
        if not isinstance(beta1, Variable) and not 0 <= beta1 < 1:
            raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
        if not isinstance(beta2, Variable) and not 0 <= beta2 < 1:
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not isinstance(epsilon, Variable) and not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")

        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )

        self.type = "adam"
        self._beta1, self._beta2, self._epsilon = beta1, beta2, epsilon
        self._lazy_mode = lazy_mode
        self._multi_precision = multi_precision
        # param.name -> FP32 master copy, filled by _create_master_weight.
        self._master_weights = {}
        # Per-optimizer defaults; presumably consumed by the base class when
        # filling in options missing from a parameter group (not shown here).
        self._default_dict = dict(
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            lazy_mode=lazy_mode,
        )

        self._use_multi_tensor = use_multi_tensor
        if self._use_multi_tensor:
            new_bucket = self._create_multi_tensor_dict
            self._param_dict = new_bucket()
            self._moment1_dict = new_bucket()
            self._moment2_dict = new_bucket()
            self._beta1_pow_acc_dict = new_bucket()
            self._beta2_pow_acc_dict = new_bucket()
            self._master_weight_dict = new_bucket()
            # Master weights only exist for FP16 parameters.
            self._master_weight_dict['FP32_LODTensor'] = None
    def _create_master_weight(self, param):
        """Return the FP32 master copy of ``param``, creating it on first use.

        A new master variable is a persistable float32 global var named after
        ``param`` (uniquified). A ``cast`` op appended to the startup
        program's global block initialises it from ``param``. Results are
        cached in ``self._master_weights`` keyed by ``param.name``.
        """
        if param.name in self._master_weights:
            return self._master_weights[param.name]

        assert isinstance(self.helper, LayerHelper)

        master = layers.create_global_var(
            name=unique_name.generate(param.name + "_fp32_master"),
            shape=param.shape,
            value=0,
            dtype='float32',
            persistable=True,
        )
        startup_block = self.helper.startup_program.global_block()
        startup_block.append_op(
            type="cast",
            inputs={"X": [param]},
            outputs={"Out": [master]},
            attrs={
                "in_dtype": param.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
            },
        )
        self._master_weights[param.name] = master
        return master

    def _get_accumulator(self, name, param):
        """Look up the accumulator ``name`` that belongs to ``param``.

        When the optimizer has a ``name`` it is used as a prefix
        (``<optimizer name>_<name>``). For FP16 parameters under
        ``multi_precision`` the lookup goes through the parameter's FP32
        master weight, which is the variable the accumulators were created on.

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        Raises:
            Exception: when no accumulator ``name`` exists for the parameter.
        """
        key = name if self._name is None else self._name + "_" + name
        use_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        owner = self._master_weights[param.name] if use_master else param
        owner_name = owner.name

        registry = self._accumulators
        if key in registry and owner_name in registry[key]:
            return registry[key][owner_name]
        raise Exception(
            "Accumulator {} does not exist for parameter {}".format(
                key, owner_name
            )
        )

    def _add_moments_pows(self, p):
        """Register Adam's moment and beta-power accumulators for ``p``.

        Args:
            p: the variable the accumulators are attached to (the FP32 master
                weight when an FP16 parameter is trained with multi_precision).

        FP16/BF16 variables get FP32 accumulators. The beta1/beta2 power
        accumulators are shape-[1] ``LOD_TENSOR`` accumulators placed on CPU
        and start at beta1/beta2 respectively.
        """
        acc_dtype = p.dtype
        if acc_dtype in (core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16):
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=self._initial_beta_pow(self._beta1, 0.9),
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=self._initial_beta_pow(self._beta2, 0.999),
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    @staticmethod
    def _initial_beta_pow(beta, static_default):
        """Initial value for a beta-power accumulator.

        A plain number is used directly. A Tensor beta is read back in dygraph
        mode (it already holds its value there), so the bias correction
        matches the beta actually used by the update. In a static graph the
        value is not available while building, so ``static_default`` is used.
        """
        if not isinstance(beta, Variable):
            return beta
        if framework._non_static_mode():
            return beta.numpy().item(0)
        return static_default

    def _create_accumulators(self, block, parameters):
        """Create moment and beta-power accumulators for every parameter.

        Args:
            block: the ``framework.Block`` being optimized.
            parameters: a list of parameters, or a parameter-group dict which
                is first normalised through ``_update_param_group``.

        With ``multi_precision`` enabled, each FP16 parameter gets an FP32
        master weight and the accumulators are attached to that master copy.
        Without it, a warning is emitted and the accumulators are attached to
        the FP16 parameter itself.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            if p.dtype == core.VarDesc.VarType.FP16:
                if self._multi_precision:
                    self._add_moments_pows(self._create_master_weight(p))
                    continue
                # The two literals are concatenated, so the trailing space
                # keeps the sentences apart.
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adam optimizer."
                )
            self._add_moments_pows(p)

    @staticmethod
    def _to_python_scalar(value):
        """Return the first element of ``value`` if it is a Tensor, else ``value``.

        Used on the dygraph paths, where beta1/beta2/epsilon are handed to the
        C ops as plain values rather than as graph inputs.
        """
        if isinstance(value, Variable):
            return value.numpy().item(0)
        return value

    def _append_optimize_op(self, block, param_and_grad):
        """Run (dygraph) or append (static graph) one Adam update.

        Args:
            block: the ``framework.Block`` the op is appended to in static mode.
            param_and_grad: a ``(param, grad)`` pair, or a parameter-group
                dict which is first normalised by ``_update_param_group``.

        Returns:
            ``None`` in dygraph modes, where the update runs immediately;
            otherwise the appended ``adam`` operator.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)
        param, grad = param_and_grad[0], param_and_grad[1]

        moment1 = self._get_accumulator(self._moment1_acc_str, param)
        moment2 = self._get_accumulator(self._moment2_acc_str, param)
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param)
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, param)
        # FP16 parameters trained with multi_precision also update their FP32
        # master copy.
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        master_weight = (
            self._master_weights[param.name] if find_master else None
        )

        lr = self._create_param_lr(param_and_grad)

        if framework.in_dygraph_mode():
            found_inf = self._get_auxiliary_var('found_inf')
            # Tensor-valued hyper-parameters (epsilon included, consistent
            # with beta1/beta2) are read back to Python numbers.
            _, _, _, _, _, _ = _C_ops.adam_(
                param,
                grad,
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                found_inf,
                self._to_python_scalar(self._beta1),
                self._to_python_scalar(self._beta2),
                self._to_python_scalar(self._epsilon),
                self._lazy_mode,
                1000,
                find_master,
                False,
            )
            return None

        if framework._in_legacy_dygraph():
            # Outputs are the same variables as the inputs. 'epsilon' is an op
            # attribute here, so a Tensor epsilon must be converted just like
            # beta1/beta2 rather than passed through raw.
            _, _, _, _, _, _ = _legacy_C_ops.adam(
                param,
                grad,
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                param,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                'epsilon',
                self._to_python_scalar(self._epsilon),
                'lazy_mode',
                self._lazy_mode,
                'min_row_size_to_use_multithread',
                1000,
                'beta1',
                self._to_python_scalar(self._beta1),
                'beta2',
                self._to_python_scalar(self._beta2),
                'multi_precision',
                find_master,
            )
            return None

        inputs = {
            "Param": [param],
            "Grad": [grad],
            "LearningRate": [lr],
            "Moment1": [moment1],
            "Moment2": [moment2],
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc],
        }
        outputs = {
            "ParamOut": [param],
            "Moment1Out": [moment1],
            "Moment2Out": [moment2],
            "Beta1PowOut": [beta1_pow_acc],
            "Beta2PowOut": [beta2_pow_acc],
        }
        attrs = {
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000,
            "multi_precision": find_master,
        }

        # In a static graph, Tensor hyper-parameters become op inputs and
        # plain numbers become op attributes.
        if isinstance(self._beta1, Variable):
            inputs['Beta1Tensor'] = self._beta1
        else:
            attrs['beta1'] = self._beta1
        if isinstance(self._beta2, Variable):
            inputs['Beta2Tensor'] = self._beta2
        else:
            attrs['beta2'] = self._beta2
        if isinstance(self._epsilon, Variable):
            inputs['EpsilonTensor'] = self._epsilon
        else:
            attrs['epsilon'] = self._epsilon

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adam_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

        return adam_op
    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Parameters with ``stop_gradient`` set, or without a gradient, are
        skipped. For a flat parameter list, a RuntimeError is raised when a
        gradient is sparse while weight_decay (regularization) is set.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                a = paddle.rand([2,13], dtype="float32")
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
                                            parameters = linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                adam.clear_grad()
        """
        if not isinstance(self._parameter_list[0], dict):
            params_grads = []
            for param in self._parameter_list:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                if grad_var is None:
                    continue
                # Eager tensors and legacy dygraph variables expose sparsity
                # through different methods.
                if in_dygraph_mode():
                    is_sparse = (
                        hasattr(grad_var, "is_selected_rows")
                        and grad_var.is_selected_rows()
                    )
                else:
                    is_sparse = (
                        hasattr(grad_var, "_is_sparse")
                        and grad_var._is_sparse()
                    )
                if is_sparse and self.regularization is not None:
                    raise RuntimeError(
                        "Adam don't support weight_decay with sparse parameters, please set it to None."
                    )
                params_grads.append((param, grad_var))

            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=params_grads,
                param_group_idx=0,
            )
        else:
            # optimize parameters in groups, each with its own options
            for idx, param_group in enumerate(self._param_groups):
                params_grads = defaultdict(list)
                for param in param_group['params']:
                    if param.stop_gradient:
                        continue
                    grad_var = param._grad_ivar()
                    if grad_var is not None:
                        params_grads['params'].append((param, grad_var))
                params_grads.update(
                    {k: v for k, v in param_group.items() if k != 'params'}
                )
                self._apply_optimize(
                    loss=None,
                    startup_program=None,
                    params_grads=params_grads,
                    param_group_idx=idx,
                )
    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        Group parameters and their Adam state by data type for the
        multi-tensor (merged) update.

        Each parameter, its moment1/moment2 accumulators and its beta1/beta2
        power accumulators are appended to the ``'FP32_LODTensor'`` or
        ``'FP16_LODTensor'`` bucket, at index ``param_group_idx``, of the
        matching ``self._*_dict``. For FP16 parameters the FP32 master weight
        is recorded as well when ``multi_precision`` is on; otherwise the FP16
        master-weight entry is set to ``None``.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
            param_group_idx: index of the parameter group being initialised

        Raises:
            ValueError: if a parameter is neither float32 nor float16.
        """
        self._create_accumulators(target_block, parameters)
        for param in parameters:
            moment1 = self._get_accumulator(self._moment1_acc_str, param)
            moment2 = self._get_accumulator(self._moment2_acc_str, param)
            beta1_pow_acc = self._get_accumulator(
                self._beta1_pow_acc_str, param
            )
            beta2_pow_acc = self._get_accumulator(
                self._beta2_pow_acc_str, param
            )

            if param.dtype == paddle.float32:
                key = 'FP32_LODTensor'
            elif param.dtype == paddle.float16:
                key = 'FP16_LODTensor'
            else:
                raise ValueError(
                    "Now multi_tensor_adam only support fp32 and fp16 parameters and grad is LOD_TENSOR."
                )

            self._param_dict[key][param_group_idx].append(param)
            self._moment1_dict[key][param_group_idx].append(moment1)
            self._moment2_dict[key][param_group_idx].append(moment2)
            self._beta1_pow_acc_dict[key][param_group_idx].append(
                beta1_pow_acc
            )
            self._beta2_pow_acc_dict[key][param_group_idx].append(
                beta2_pow_acc
            )

            if key == 'FP16_LODTensor':
                if self._multi_precision:
                    self._master_weight_dict[key][param_group_idx].append(
                        self._master_weights[param.name]
                    )
                else:
                    self._master_weight_dict[key] = None

    def _append_optimize_multi_tensor_op(
633 634 635 636
        self,
        target_block,
        parameters_and_grads,
        param_group_idx,
637
    ):
638
        """
Z
zhangbo9674 已提交
639 640 641 642 643 644 645 646 647 648 649 650
        For Multi Tensor, append optimize merged_operator to block.
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
        lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}

        if isinstance(parameters_and_grads, list):
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
651 652 653 654 655
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
656 657 658
                        grad_dict['FP32_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_LODTensor'].append(lr)
659 660 661 662 663
                    elif (
                        param_and_grad[0].dtype == paddle.float16
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
664 665 666 667 668 669 670 671 672 673
                        grad_dict['FP16_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_LODTensor'].append(lr)
        else:
            for param_and_grad in parameters_and_grads['params']:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    param_grad_dict = dict()
                    param_grad_dict['params'] = param_and_grad
674 675 676 677 678 679 680
                    param_grad_dict.update(
                        {
                            k: v
                            for k, v in parameters_and_grads.items()
                            if k != 'params'
                        }
                    )
Z
zhangbo9674 已提交
681
                    param_and_grad = self._update_param_group(param_grad_dict)
682 683 684 685 686
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
687 688 689
                        grad_dict['FP32_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_LODTensor'].append(lr)
690 691 692 693 694
                    elif (
                        param_and_grad[0].dtype == paddle.float16
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.LOD_TENSOR
                    ):
Z
zhangbo9674 已提交
695 696 697 698 699 700
                        grad_dict['FP16_LODTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_LODTensor'].append(lr)

        multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
        for key in multi_tensor_list:
701
            if len(self._param_dict[key][param_group_idx]) > 0:
702
                find_master = self._multi_precision and key == 'FP16_LODTensor'
Z
zhangbo9674 已提交
703

704 705 706 707 708 709 710 711 712 713
                _beta1 = (
                    self._beta1
                    if not isinstance(self._beta1, Variable)
                    else self._beta1.numpy().item(0)
                )
                _beta2 = (
                    self._beta2
                    if not isinstance(self._beta2, Variable)
                    else self._beta2.numpy().item(0)
                )
Z
zhangbo9674 已提交
714

J
Jiabin Yang 已提交
715
                if framework._non_static_mode():
716 717 718 719 720 721
                    master_weight = self._master_weight_dict[key]
                    master_weight = (
                        master_weight[param_group_idx]
                        if master_weight is not None
                        else None
                    )
722
                    if in_dygraph_mode():
723

724
                        _, _, _, _, _, _ = _C_ops.merged_adam_(
725
                            self._param_dict[key][param_group_idx],
726 727
                            grad_dict[key],
                            lr_dict[key],
728 729 730 731 732
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
733 734 735 736 737 738
                            _beta1,
                            _beta2,
                            self._epsilon,
                            find_master,
                            False,
                        )
739 740
                    else:
                        _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
741
                            self._param_dict[key][param_group_idx],
742 743
                            grad_dict[key],
                            lr_dict[key],
744 745 746 747 748 749 750 751 752 753 754
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
                            self._param_dict[key][param_group_idx],
                            self._moment1_dict[key][param_group_idx],
                            self._moment2_dict[key][param_group_idx],
                            self._beta1_pow_acc_dict[key][param_group_idx],
                            self._beta2_pow_acc_dict[key][param_group_idx],
                            master_weight,
755 756 757 758 759 760 761 762 763
                            'epsilon',
                            self._epsilon,
                            'beta1',
                            _beta1,
                            'beta2',
                            _beta2,
                            'multi_precision',
                            find_master,
                        )
Z
zhangbo9674 已提交
764 765
                else:
                    inputs = {
766
                        "Param": self._param_dict[key][param_group_idx],
Z
zhangbo9674 已提交
767 768
                        "Grad": grad_dict[key],
                        "LearningRate": lr_dict[key],
769 770 771 772 773 774 775 776
                        "Moment1": self._moment1_dict[key][param_group_idx],
                        "Moment2": self._moment2_dict[key][param_group_idx],
                        "Beta1Pow": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2Pow": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
Z
zhangbo9674 已提交
777 778
                    }
                    outputs = {
779 780 781 782 783 784 785 786 787
                        "ParamOut": self._param_dict[key][param_group_idx],
                        "Moment1Out": self._moment1_dict[key][param_group_idx],
                        "Moment2Out": self._moment2_dict[key][param_group_idx],
                        "Beta1PowOut": self._beta1_pow_acc_dict[key][
                            param_group_idx
                        ],
                        "Beta2PowOut": self._beta2_pow_acc_dict[key][
                            param_group_idx
                        ],
Z
zhangbo9674 已提交
788 789 790 791
                    }
                    attrs = {
                        "epsilon": self._epsilon,
                        "beta1": _beta1,
792
                        "beta2": _beta2,
Z
zhangbo9674 已提交
793
                    }
794
                    if find_master:
795 796 797
                        inputs["MasterParam"] = self._master_weight_dict[key][
                            param_group_idx
                        ]
Z
zhangbo9674 已提交
798
                        outputs["MasterParamOut"] = self._master_weight_dict[
799
                            key
800
                        ][param_group_idx]
801
                        attrs["multi_precision"] = find_master
802 803 804 805 806 807 808
                    target_block.append_op(
                        type="merged_adam",
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                        stop_gradient=True,
                    )
Z
zhangbo9674 已提交
809 810
        return None

811 812 813 814
    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
815 816 817
        self._lazy_mode = parameters.get(
            'lazy_mode', self._default_dict['lazy_mode']
        )
818 819
        parameters = parameters.get('params')
        return parameters