momentum.py 24.3 KB
Newer Older
J
Jiawei Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

J
Jiangxinz 已提交
15 16
import warnings

J
Jiawei Wang 已提交
17 18 19 20
from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, name_scope
21
from ..fluid.layer_helper import LayerHelper
H
huangxu96 已提交
22 23
from ..fluid import unique_name
from ..fluid import layers
24
import paddle.fluid as fluid
H
huangxu96 已提交
25
from paddle.fluid.regularizer import L2DecayRegularizer
W
wanghuancoder 已提交
26
from paddle import _C_ops
27
import paddle
28
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
J
Jiawei Wang 已提交
29

30 31
# Intentionally empty: `from <this module> import *` exports no names.
__all__ = []

J
Jiawei Wang 已提交
32 33

class Momentum(Optimizer):
    r"""

    Simple Momentum optimizer with velocity state.

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Parameters:

        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        momentum (float, optional): Momentum factor. The default value is 0.9.
        parameters (list|tuple, optional): List|Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate and weight decay; in that case
            ``parameters`` is a list of dict. Note that the learning_rate in a parameter group
            is a scale applied to the base learning_rate.
            The default value is None in static mode, at this time all parameters will be updated.
        use_nesterov (bool, optional): Whether to use Nesterov momentum (see the equations above).
            The default value is False.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            A float or L2Decay is fused into the momentum update instead of being applied
            as a separate regularization op.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating.
            FP16 parameters are then updated through an FP32 master copy. Default is False.
        rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating.
            Often choose to be ``1.0/batch_size``.
        use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once.
            Default is False.
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            loss.backward()
            momentum.step()
            momentum.clear_grad()

            #Note that the learning_rate of linear_2 is 0.01 (0.1 * 0.1).
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1
                }],
                weight_decay=0.01,
                momentum=0.9)
            loss.backward()
            momentum.step()
            momentum.clear_grad()

    """
    _velocity_acc_str = "velocity"

    def __init__(self,
                 learning_rate=0.001,
                 momentum=0.9,
                 parameters=None,
                 use_nesterov=False,
                 weight_decay=None,
                 grad_clip=None,
                 multi_precision=False,
                 rescale_grad=1.0,
                 use_multi_tensor=False,
                 name=None):
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        if momentum is None:
            raise ValueError("momentum is not set")

        # Each parameter group may carry its own weight decay. L2 decay
        # (float / L2DecayRegularizer) is fused into the momentum op, so it is
        # recorded as regularization_method/coeff and hidden from the generic
        # regularization pass; other regularizers are passed through as-is.
        # This must happen before the base __init__ sees the groups.
        if (isinstance(parameters, (list, tuple)) and parameters
                and isinstance(parameters[0], dict)):
            for param_group in parameters:
                decay = param_group.get('weight_decay', weight_decay)
                reg_method, reg_coeff = self._update_regularization(decay)
                param_group['regularization_method'] = reg_method
                param_group['regularization_coeff'] = reg_coeff
                param_group['weight_decay'] = self._python_regularizer(decay)

        super(Momentum, self).__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=self._python_regularizer(weight_decay),
            grad_clip=grad_clip,
            name=name)
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)
        self._regularization_method, self._regularization_coeff = (
            self._update_regularization(weight_decay))
        self._multi_precision = multi_precision
        self._rescale_grad = rescale_grad
        # param.name -> FP32 master weight variable (multi_precision only).
        self._master_weights = {}

        # Fallback values used by _update_param_group for keys a group omits.
        self._default_dict = {
            'momentum': momentum,
            'use_nesterov': self._use_nesterov,
            'rescale_grad': rescale_grad,
            'regularization_method': self._regularization_method,
            'regularization_coeff': self._regularization_coeff,
        }
        self._use_multi_tensor = use_multi_tensor
        if self._use_multi_tensor:
            # Per-dtype buckets consumed by the merged_momentum op.
            self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
            self._velocity_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
            self._master_weight_dict = {
                'FP32_LODTensor': None,
                'FP16_LODTensor': []
            }
            self._regularization_method_dict = {
                'FP32_LODTensor': [],
                'FP16_LODTensor': []
            }
            self._regularization_coeff_dict = {
                'FP32_LODTensor': [],
                'FP16_LODTensor': []
            }

    @staticmethod
    def _python_regularizer(weight_decay):
        """Return the regularizer the base Optimizer should apply, or None.

        A float or ``L2DecayRegularizer`` is fused into the momentum op (see
        ``_update_regularization``), so the base class must not apply it again.
        Any other value (e.g. an L1 regularizer, or None) is returned unchanged.
        """
        if isinstance(weight_decay, (L2DecayRegularizer, float)):
            return None
        return weight_decay

    def _update_regularization(self, weight_decay):
        """Translate ``weight_decay`` into the momentum op's fused L2 attrs.

        Returns:
            tuple: ``("l2_decay", coeff)`` for a float or
            ``L2DecayRegularizer``, otherwise ``("", 0.0)`` (no fused decay).
        """
        if isinstance(weight_decay, L2DecayRegularizer):
            return "l2_decay", weight_decay._regularization_coeff
        if isinstance(weight_decay, float):
            return "l2_decay", weight_decay
        return "", 0.0

    def _param_regularization(self, param):
        """Return the fused ``(method, coeff)`` to use for ``param``.

        A parameter-level ``L2DecayRegularizer`` was skipped in
        ``_create_regularization_of_grad`` and is fused here instead. Any other
        parameter-level regularizer has already been applied to the gradient,
        so fused decay is disabled. Otherwise the optimizer-level setting
        applies.
        """
        regularizer = getattr(param, 'regularizer', None)
        if isinstance(regularizer, L2DecayRegularizer):
            return "l2_decay", regularizer._regularization_coeff
        if regularizer is not None:
            return "", 0.0
        return self._regularization_method, self._regularization_coeff

    def _create_master_weight(self, param):
        """Return the FP32 master copy of ``param``, creating it on first use.

        The master weight is a persistable global variable filled in the
        startup program by casting ``param`` to FP32. It is cached in
        ``self._master_weights`` by ``param.name``.
        """
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = layers.create_global_var(name=var_name,
                                           shape=param.shape,
                                           value=0,
                                           dtype='float32',
                                           persistable=True)
            block = self.helper.startup_program.global_block()
            block.append_op(type="cast",
                            inputs={"X": [param]},
                            outputs={"Out": [var]},
                            attrs={
                                "in_dtype": param.dtype,
                                "out_dtype": core.VarDesc.VarType.FP32
                            })
            self._master_weights[param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        For FP16 parameters under multi_precision, the accumulator is looked
        up on the FP32 master weight instead of on ``param`` itself.

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched

        Returns:
            accumulator variable for the parameter

        Raises:
            Exception: if no such accumulator was created.
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (self._multi_precision
                       and param.dtype == core.VarDesc.VarType.FP16)
        target_param = (self._master_weights[param.name]
                        if find_master else param)
        target_name = target_param.name
        if (name not in self._accumulators
                or target_name not in self._accumulators[name]):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name))
        return self._accumulators[name][target_name]

    def _create_accumulators(self, block, parameters):
        """Create the velocity accumulator for every parameter.

        Under multi_precision, an FP16 parameter gets an FP32 master weight and
        its velocity is attached to that master weight. Without
        multi_precision, an FP16 parameter triggers an accuracy warning.
        """
        assert isinstance(block, framework.Block)

        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            is_fp16 = p.dtype == core.VarDesc.VarType.FP16
            if self._multi_precision and is_fp16:
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._velocity_acc_str, master_p)
                continue
            if is_fp16:
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Momentum optimizer."
                )
            self._add_accumulator(self._velocity_acc_str, p)

    def _create_regularization_of_grad(self, param, grad, regularization=None):
        """ Create and add backward regularization Operators

        Function helper of append_regularization_ops.
        """
        # If ParamAttr is set to L2Decay, we skip doing regularization here
        # and fuse L2Decay into the momentum op instead (see
        # _param_regularization / _append_optimize_op).
        if hasattr(param, 'regularizer') and isinstance(param.regularizer,
                                                        L2DecayRegularizer):
            return grad
        return super(Momentum, self)._create_regularization_of_grad(
            param, grad, regularization)

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one momentum update to a ``(param, grad)`` pair.

        ``param_and_grad`` may also be a parameter-group dict, in which case the
        group's options are loaded first. In legacy dygraph the update runs
        eagerly and None is returned. In new dygraph, the result of
        ``final_state_momentum`` is returned. In static mode, the appended
        momentum op is returned.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        param, grad = param_and_grad[0], param_and_grad[1]
        velocity_acc = self._get_accumulator(self._velocity_acc_str, param)
        lr = self._create_param_lr(param_and_grad)

        # L2 decay is fused into the momentum op.
        regularization_method, regularization_coeff = (
            self._param_regularization(param))

        find_master = (self._multi_precision
                       and param.dtype == core.VarDesc.VarType.FP16)
        master_weight = (self._master_weights[param.name]
                         if find_master else None)

        if _in_legacy_dygraph():
            # rescale_grad is passed here too, so that it is honoured exactly
            # as in the static and final_state paths.
            _C_ops.momentum(param, grad, velocity_acc, lr, master_weight,
                            param, velocity_acc, master_weight, 'mu',
                            self._momentum, 'use_nesterov', self._use_nesterov,
                            'regularization_method', regularization_method,
                            'regularization_coeff', regularization_coeff,
                            'multi_precision', find_master, 'rescale_grad',
                            self._rescale_grad)
            return None
        if in_dygraph_mode():
            return _C_ops.final_state_momentum(param, grad, velocity_acc, lr,
                                               master_weight, self._momentum,
                                               self._use_nesterov,
                                               regularization_method,
                                               regularization_coeff,
                                               find_master, self._rescale_grad)

        attrs = {
            "mu": self._momentum,
            "use_nesterov": self._use_nesterov,
            "regularization_method": regularization_method,
            "regularization_coeff": regularization_coeff,
            "multi_precision": find_master,
            "rescale_grad": self._rescale_grad
        }
        inputs = {
            "Param": [param],
            "Grad": [grad],
            "Velocity": [velocity_acc],
            "LearningRate": [lr]
        }
        outputs = {"ParamOut": [param], "VelocityOut": [velocity_acc]}
        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        # create the momentum optimize op
        momentum_op = block.append_op(type=self.type,
                                      inputs=inputs,
                                      outputs=outputs,
                                      attrs=attrs,
                                      stop_gradient=True)
        return momentum_op

    def _multi_tensor_init(self, target_block, parameters):
        """Bucket parameters and their optimizer state by dtype for merged_momentum.

        Creates accumulators, then appends each parameter, its velocity, its
        master weight (FP16 with multi_precision only), and its fused
        regularization setting to the 'FP32_LODTensor' / 'FP16_LODTensor'
        buckets.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer

        Raises:
            ValueError: if a parameter is neither float32 nor float16.
        """
        self._create_accumulators(target_block, parameters)
        for param in parameters:
            velocity_acc = self._get_accumulator(self._velocity_acc_str, param)
            regularization_method, regularization_coeff = (
                self._param_regularization(param))
            if param.dtype == paddle.float32:
                key = 'FP32_LODTensor'
            elif param.dtype == paddle.float16:
                key = 'FP16_LODTensor'
            else:
                raise ValueError(
                    "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
                )
            self._param_dict[key].append(param)
            self._velocity_dict[key].append(velocity_acc)
            if key == 'FP16_LODTensor':
                # FP32 parameters have no master weight.
                if self._multi_precision:
                    self._master_weight_dict[key].append(
                        self._master_weights[param.name])
                else:
                    self._master_weight_dict[key] = None
            self._regularization_method_dict[key].append(regularization_method)
            self._regularization_coeff_dict[key].append(regularization_coeff)

    @staticmethod
    def _multi_tensor_key(param_and_grad):
        """Return the multi-tensor bucket for a pair, or None if unsupported.

        Only float32/float16 parameters whose gradient is a LOD_TENSOR are
        handled by merged_momentum.
        """
        dtype = param_and_grad[0].dtype
        is_dense = param_and_grad[1].type == core.VarDesc.VarType.LOD_TENSOR
        if dtype == paddle.float32 and is_dense:
            return 'FP32_LODTensor'
        if dtype == paddle.float16 and is_dense:
            return 'FP16_LODTensor'
        return None

    def _append_optimize_multi_tensor_op(self, target_block,
                                         parameters_and_grads):
        """
        For Multi Tensor, append optimize merged_operator to block.

        ``parameters_and_grads`` is either a list of ``(param, grad)`` pairs or
        a parameter-group dict whose 'params' entry holds the pairs.
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
        lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}

        if isinstance(parameters_and_grads, list):
            group_options = None
            pairs = parameters_and_grads
        else:
            group_options = {
                k: v
                for k, v in parameters_and_grads.items() if k != 'params'
            }
            pairs = parameters_and_grads['params']

        for param_and_grad in pairs:
            if param_and_grad[1] is None:
                continue
            if param_and_grad[0].stop_gradient is not False:
                continue
            if group_options is not None:
                # Load this group's options (momentum, nesterov, ...) before
                # computing the per-parameter learning rate.
                param_and_grad = self._update_param_group(
                    dict(group_options, params=param_and_grad))
            key = self._multi_tensor_key(param_and_grad)
            if key is None:
                continue
            grad_dict[key].append(param_and_grad[1])
            lr_dict[key].append(self._create_param_lr(param_and_grad))

        # NOTE(review): rescale_grad is not forwarded to merged_momentum here,
        # unlike the single-tensor path; confirm whether the op supports it.
        for key in ('FP32_LODTensor', 'FP16_LODTensor'):
            if not self._param_dict[key]:
                continue
            find_master = self._multi_precision and key == 'FP16_LODTensor'

            if framework._non_static_mode():
                _C_ops.merged_momentum(
                    self._param_dict[key], grad_dict[key],
                    self._velocity_dict[key], lr_dict[key],
                    self._master_weight_dict[key], self._param_dict[key],
                    self._velocity_dict[key], self._master_weight_dict[key],
                    'mu', self._momentum, 'use_nesterov', self._use_nesterov,
                    'regularization_method',
                    self._regularization_method_dict[key],
                    'regularization_coeff',
                    self._regularization_coeff_dict[key], 'multi_precision',
                    find_master)
            else:
                inputs = {
                    "Param": self._param_dict[key],
                    "Grad": grad_dict[key],
                    "Velocity": self._velocity_dict[key],
                    "LearningRate": lr_dict[key],
                }
                outputs = {
                    "ParamOut": self._param_dict[key],
                    "VelocityOut": self._velocity_dict[key],
                }
                attrs = {
                    "mu": self._momentum,
                    "use_nesterov": self._use_nesterov,
                    "regularization_method":
                    self._regularization_method_dict[key],
                    "regularization_coeff":
                    self._regularization_coeff_dict[key],
                }
                if find_master:
                    inputs["MasterParam"] = self._master_weight_dict[key]
                    outputs["MasterParamOut"] = self._master_weight_dict[key]
                    attrs["multi_precision"] = find_master
                target_block.append_op(type="merged_momentum",
                                       inputs=inputs,
                                       outputs=outputs,
                                       attrs=attrs,
                                       stop_gradient=True)
        return None

    def _update_param_group(self, parameters):
        """Load a parameter group's options into the optimizer state.

        Missing keys fall back to ``self._default_dict``. Returns the group's
        'params' entry.
        """
        self._momentum = parameters.get('momentum',
                                        self._default_dict['momentum'])
        self._use_nesterov = bool(
            parameters.get('use_nesterov', self._default_dict['use_nesterov']))
        self._rescale_grad = parameters.get('rescale_grad',
                                            self._default_dict['rescale_grad'])
        self._regularization_method = parameters.get(
            'regularization_method',
            self._default_dict['regularization_method'])
        self._regularization_coeff = parameters.get(
            'regularization_coeff', self._default_dict['regularization_coeff'])
        return parameters.get('params')