# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.regularizer import L2DecayRegularizer

from ..fluid import core, framework
from .optimizer import Optimizer

# Empty on purpose: ``from <this module> import *`` exports nothing.
# NOTE(review): ``Momentum`` is presumably exposed through the package's
# ``__init__`` instead — confirm.
__all__ = []


class Momentum(Optimizer):
    r"""

    Simple Momentum optimizer with velocity state.

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Parameters:

        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        momentum (float, optional): Momentum factor. The default value is 0.9.
        parameters (list|tuple, optional): List|Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. You can also specify different options for \
            different parameter groups, such as the learning rate, weight decay, etc.; \
            in that case ``parameters`` is a list of dict. Note that the learning_rate in parameter groups \
            represents a scale of the base learning_rate. \
            The default value is None in static graph mode, at this time all parameters will be updated.
        use_nesterov (bool, optional): Whether to use the Nesterov update shown above. Default is False.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is False.
        rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \
            Often choose to be ``1.0/batch_size``.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once. Default is False.
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            loss.backward()
            momentum.step()
            momentum.clear_grad()

            # Note that the effective learning_rate of linear_2 is 0.1 * 0.1 = 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1
                }],
                weight_decay=0.01,
                momentum=0.9)
            loss.backward()
            momentum.step()
            momentum.clear_grad()

    """
    _velocity_acc_str = "velocity"

    def __init__(
        self,
        learning_rate=0.001,
        momentum=0.9,
        parameters=None,
        use_nesterov=False,
        weight_decay=None,
        grad_clip=None,
        multi_precision=False,
        rescale_grad=1.0,
        use_multi_tensor=False,
        name=None,
    ):
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        if momentum is None:
            raise ValueError("momentum is not set")

        def fused_by_momentum(regular):
            # L2 decay (a float coefficient or an L2DecayRegularizer) is fused
            # into the momentum kernel, so it must NOT also be handed to the
            # base Optimizer as a gradient regularizer.
            return isinstance(regular, (L2DecayRegularizer, float))

        # Param-group form: a non-empty list whose items are dicts. The
        # emptiness check mirrors the base Optimizer and avoids an IndexError
        # on ``parameters=[]``.
        if (
            isinstance(parameters, list)
            and parameters
            and isinstance(parameters[0], dict)
        ):
            for param_group in parameters:
                decay = param_group.get('weight_decay', weight_decay)
                reg_method, reg_coeff = self._update_regularization(decay)
                param_group['regularization_method'] = reg_method
                param_group['regularization_coeff'] = reg_coeff
                param_group['weight_decay'] = (
                    None if fused_by_momentum(decay) else decay
                )

        py_regular = None if fused_by_momentum(weight_decay) else weight_decay
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=py_regular,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)
        (
            self._regularization_method,
            self._regularization_coeff,
        ) = self._update_regularization(weight_decay)
        self._multi_precision = multi_precision
        self._rescale_grad = rescale_grad
        self._master_weights = {}

        # Fallback values used by _update_param_group when a param group does
        # not override them.
        self._default_dict = {
            'momentum': momentum,
            'use_nesterov': self._use_nesterov,
            'rescale_grad': rescale_grad,
            'regularization_method': self._regularization_method,
            'regularization_coeff': self._regularization_coeff,
        }
        self._use_multi_tensor = use_multi_tensor
        if self._use_multi_tensor:
            self._param_dict = self._create_multi_tensor_dict()
            self._velocity_dict = self._create_multi_tensor_dict()
            self._master_weight_dict = self._create_multi_tensor_dict()
            # fp32 parameters never need a master weight.
            self._master_weight_dict['FP32_LODTensor'] = None
            self._regularization_method_dict = self._create_multi_tensor_dict()
            self._regularization_coeff_dict = self._create_multi_tensor_dict()

    def _update_regularization(self, weight_decay):
        """Translate ``weight_decay`` into fused-regularization attributes.

        Returns:
            tuple: ``("l2_decay", coeff)`` when ``weight_decay`` is a float or
            an ``L2DecayRegularizer``; ``("", 0.0)`` for anything else
            (including None).
        """
        if isinstance(weight_decay, L2DecayRegularizer):
            return "l2_decay", weight_decay._regularization_coeff
        if isinstance(weight_decay, float):
            return "l2_decay", weight_decay
        return "", 0.0

    def _param_regularization(self, param):
        """Return the ``(method, coeff)`` pair to fuse into the op for ``param``.

        An ``L2DecayRegularizer`` set on the parameter itself is skipped by
        ``_create_regularization_of_grad`` and is therefore fused here. Any
        other parameter-level regularizer has already been applied to the
        gradient, so fused decay is disabled for that parameter. Otherwise
        the current optimizer / param-group setting is used.
        """
        regularizer = getattr(param, 'regularizer', None)
        if isinstance(regularizer, L2DecayRegularizer):
            return "l2_decay", regularizer._regularization_coeff
        if regularizer is not None:
            return "", 0.0
        return self._regularization_method, self._regularization_coeff

    def _create_accumulators(self, block, parameters):
        """Create a velocity accumulator for every parameter.

        With ``multi_precision`` enabled, fp16/bf16 parameters get a master
        weight (via ``_create_master_weight``) and the velocity is attached to
        that master weight instead. Without it, a warning is emitted for
        fp16/bf16 parameters.

        Args:
            block (framework.Block): block the accumulators belong to.
            parameters (list|dict): parameters, or a param-group dict whose
                ``'params'`` entry holds them.
        """
        assert isinstance(block, framework.Block)

        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if self._is_dtype_fp16_or_bf16(p.dtype):
                if self._multi_precision:
                    master_p = self._create_master_weight(p)
                    self._add_accumulator(self._velocity_acc_str, master_p)
                    continue
                warnings.warn(
                    "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Momentum optimizer."
                )
            self._add_accumulator(self._velocity_acc_str, p)

    def _create_regularization_of_grad(self, param, grad, regularization=None):
        """Create and add backward regularization operators.

        Helper of ``append_regularization_ops``. A parameter-level
        ``L2DecayRegularizer`` is skipped here because it is fused into the
        momentum op (see ``_param_regularization``); everything else is
        delegated to the base class.
        """
        if hasattr(param, 'regularizer') and isinstance(
            param.regularizer, L2DecayRegularizer
        ):
            return grad
        return super()._create_regularization_of_grad(
            param, grad, regularization
        )

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one momentum update to a single ``(param, grad)`` pair.

        In dygraph mode the in-place ``_C_ops.momentum_`` kernel is called
        and its result returned; otherwise a ``momentum`` op is appended to
        ``block`` and returned.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            # Also loads this group's momentum / regularization settings.
            param_and_grad = self._update_param_group(param_and_grad)

        param = param_and_grad[0]
        velocity_acc = self._get_accumulator_master(
            self._velocity_acc_str, param
        )
        lr = self._create_param_lr(param_and_grad)

        # For fusion of momentum and l2decay.
        regularization_method, regularization_coeff = (
            self._param_regularization(param)
        )

        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        master_weight = self._master_weights[param.name] if find_master else None

        if in_dygraph_mode():
            return _C_ops.momentum_(
                param,
                param_and_grad[1],
                velocity_acc,
                lr,
                master_weight,
                self._momentum,
                self._use_nesterov,
                regularization_method,
                regularization_coeff,
                find_master,
                self._rescale_grad,
            )

        attrs = {
            "mu": self._momentum,
            "use_nesterov": self._use_nesterov,
            "regularization_method": regularization_method,
            "regularization_coeff": regularization_coeff,
            "multi_precision": find_master,
            "rescale_grad": self._rescale_grad,
        }
        inputs = {
            "Param": [param],
            "Grad": [param_and_grad[1]],
            "Velocity": [velocity_acc],
            "LearningRate": [lr],
        }
        outputs = {
            "ParamOut": [param],
            "VelocityOut": [velocity_acc],
        }
        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        # create the momentum optimize op
        return block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, bf16, float32).
        This function will be overridden in the corresponding optimizer file.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
            param_group_idx: index of the parameter group being initialized

        Raises:
            ValueError: if a parameter is neither fp32 nor fp16/bf16.
        """
        self._create_accumulators(target_block, parameters)
        for param in parameters:
            velocity_acc = self._get_accumulator_master(
                self._velocity_acc_str, param
            )
            regularization_method, regularization_coeff = (
                self._param_regularization(param)
            )
            if param.dtype == paddle.float32:
                key = 'FP32_LODTensor'
            elif self._is_dtype_fp16_or_bf16(param.dtype):
                key = 'FP16_LODTensor'
            else:
                raise ValueError(
                    "Now multi_tensor_momentum only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR."
                )

            self._param_dict[key][param_group_idx].append(param)
            self._velocity_dict[key][param_group_idx].append(velocity_acc)
            # fp32 has no master weight; fp16/bf16 only with multi_precision.
            if key == 'FP16_LODTensor':
                if self._multi_precision:
                    self._master_weight_dict[key][param_group_idx].append(
                        self._master_weights[param.name]
                    )
                else:
                    self._master_weight_dict[key][param_group_idx] = None
            self._regularization_method_dict[key][param_group_idx].append(
                regularization_method
            )
            self._regularization_coeff_dict[key][param_group_idx].append(
                regularization_coeff
            )

    def _collect_multi_tensor_grad(self, param_and_grad, grad_dict, lr_dict):
        """Bucket one pair's grad and learning rate by parameter dtype.

        Only ``LOD_TENSOR`` gradients of fp32 or fp16/bf16 parameters are
        collected; any other pair is skipped.
        """
        param_dtype = param_and_grad[0].dtype
        if param_dtype == paddle.float32:
            key = 'FP32_LODTensor'
        elif self._is_dtype_fp16_or_bf16(param_dtype):
            key = 'FP16_LODTensor'
        else:
            return
        if param_and_grad[1].type != core.VarDesc.VarType.LOD_TENSOR:
            return
        grad_dict[key].append(param_and_grad[1])
        lr_dict[key].append(self._create_param_lr(param_and_grad))

    def _append_optimize_multi_tensor_op(
        self,
        target_block,
        parameters_and_grads,
        param_group_idx,
    ):
        """
        For Multi Tensor, append optimize merged_operator to block.

        ``parameters_and_grads`` is either a list of ``(param, grad)`` pairs
        or a param-group dict whose ``'params'`` entry holds such pairs.
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
        lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}

        if isinstance(parameters_and_grads, list):
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    self._collect_multi_tensor_grad(
                        param_and_grad, grad_dict, lr_dict
                    )
        else:
            group_options = {
                k: v for k, v in parameters_and_grads.items() if k != 'params'
            }
            for param_and_grad in parameters_and_grads['params']:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    param_grad_dict = {'params': param_and_grad, **group_options}
                    param_and_grad = self._update_param_group(param_grad_dict)
                    self._collect_multi_tensor_grad(
                        param_and_grad, grad_dict, lr_dict
                    )

        for key in ('FP32_LODTensor', 'FP16_LODTensor'):
            params = self._param_dict[key][param_group_idx]
            if not params:
                continue
            find_master = self._multi_precision and key == 'FP16_LODTensor'

            master_weight = self._master_weight_dict[key]
            if master_weight is not None:
                master_weight = master_weight[param_group_idx]
            velocities = self._velocity_dict[key][param_group_idx]
            reg_methods = self._regularization_method_dict[key][param_group_idx]
            reg_coeffs = self._regularization_coeff_dict[key][param_group_idx]

            if in_dygraph_mode():
                _C_ops.merged_momentum_(
                    params,
                    grad_dict[key],
                    velocities,
                    lr_dict[key],
                    master_weight,
                    self._momentum,
                    self._use_nesterov,
                    reg_methods,
                    reg_coeffs,
                    find_master,
                    self._rescale_grad,
                )
                continue

            inputs = {
                "Param": params,
                "Grad": grad_dict[key],
                "Velocity": velocities,
                "LearningRate": lr_dict[key],
            }
            outputs = {
                "ParamOut": params,
                "VelocityOut": velocities,
            }
            # rescale_grad is passed here as well so static graph matches the
            # dygraph call above (it was previously dropped in static mode).
            attrs = {
                "mu": self._momentum,
                "use_nesterov": self._use_nesterov,
                "regularization_method": reg_methods,
                "regularization_coeff": reg_coeffs,
                "multi_precision": find_master,
                "rescale_grad": self._rescale_grad,
            }
            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight
            target_block.append_op(
                type="merged_momentum",
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )
        return None

    def _update_param_group(self, parameters):
        """Load a param group's options into ``self`` and return its params.

        Options missing from the group fall back to ``self._default_dict``.
        """
        defaults = self._default_dict
        self._momentum = parameters.get('momentum', defaults['momentum'])
        self._use_nesterov = bool(
            parameters.get('use_nesterov', defaults['use_nesterov'])
        )
        self._rescale_grad = parameters.get(
            'rescale_grad', defaults['rescale_grad']
        )
        self._regularization_method = parameters.get(
            'regularization_method', defaults['regularization_method']
        )
        self._regularization_coeff = parameters.get(
            'regularization_coeff', defaults['regularization_coeff']
        )
        return parameters.get('params')