momentum.py 14.9 KB
Newer Older
J
Jiawei Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

J
Jiangxinz 已提交
15 16
import warnings

J
Jiawei Wang 已提交
17 18 19 20
from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, name_scope
21
from ..fluid.layer_helper import LayerHelper
H
huangxu96 已提交
22 23
from ..fluid import unique_name
from ..fluid import layers
24
import paddle.fluid as fluid
H
huangxu96 已提交
25
from paddle.fluid.regularizer import L2DecayRegularizer
W
wanghuancoder 已提交
26
from paddle import _C_ops
J
Jiawei Wang 已提交
27

28 29
__all__ = []

J
Jiawei Wang 已提交
30 31

class Momentum(Optimizer):
32
    r"""
J
Jiawei Wang 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56

    Simple Momentum optimizer with velocity state

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Parameters:

        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        momentum (float): Momentum factor. The default value is 0.9.
57 58 59 60 61
        parameters (list|tuple, optional): List|Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are a list of dicts. Note that the learning_rate in parameter groups \
            represents the scale of base learning_rate. \
J
Jiawei Wang 已提交
62 63
            The default value is None in static mode, in which case all parameters will be updated.
        use_nesterov (bool, optional): Enables Nesterov momentum. The default value is False.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
64 65 66 67 68 69
            It can be a float value as the coefficient of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
J
Jiawei Wang 已提交
70 71 72 73
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
H
huangxu96 已提交
74 75 76
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \
            Often choose to be ``1.0/batch_size``.
J
Jiawei Wang 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)
            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")
            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            back = out.backward()
            momentum.step()
            momentum.clear_grad()
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1
                }],
                weight_decay=0.01,
                momentum=0.9)                   
            out.backward()
            momentum.step()
            momentum.clear_grad()

J
Jiawei Wang 已提交
120 121 122 123 124 125 126 127 128 129
    """
    _velocity_acc_str = "velocity"

    def __init__(self,
                 learning_rate=0.001,
                 momentum=0.9,
                 parameters=None,
                 use_nesterov=False,
                 weight_decay=None,
                 grad_clip=None,
                 multi_precision=False,
                 rescale_grad=1.0,
                 name=None):
        """Set up the optimizer state and, in dygraph mode, the velocity accumulators.

        Args:
            learning_rate: learning rate forwarded to the base ``Optimizer``.
            momentum: momentum factor passed to the op as ``mu``.
            parameters: list/tuple of parameters, or list/tuple of
                parameter-group dicts (each with a ``'params'`` entry and
                optional per-group overrides such as ``'weight_decay'``).
            use_nesterov: truthy to enable the Nesterov update.
            weight_decay: float coefficient or regularizer object. A float or
                an ``L2DecayRegularizer`` is fused into the momentum op and is
                therefore NOT forwarded to the base class as a regularizer.
            grad_clip: gradient clipping strategy forwarded to the base class.
            multi_precision: whether FP16 parameters use FP32 master weights.
            rescale_grad: factor stored for the op's ``rescale_grad`` attribute.
            name: optional name forwarded to the base class.

        Raises:
            ValueError: if ``learning_rate`` or ``momentum`` is None.
        """
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        if momentum is None:
            raise ValueError("momentum is not set")

        def fused_into_op(decay):
            # Floats and L2Decay are applied inside the momentum op itself, so
            # they must not also be handed to the base class as a regularizer.
            return isinstance(decay, (L2DecayRegularizer, float))

        # Accept both documented container types and tolerate an empty one.
        if (isinstance(parameters, (list, tuple)) and len(parameters) > 0 and
                isinstance(parameters[0], dict)):
            prepared_groups = []
            for param_group in parameters:
                # Work on a shallow copy: overwriting 'weight_decay' in the
                # caller's dict would make the decay disappear if the same
                # group dicts were reused for another optimizer.
                group = dict(param_group)
                decay = group.get('weight_decay', weight_decay)
                reg_method, reg_coeff = self._update_regularization(decay)
                group['regularization_method'] = reg_method
                group['regularization_coeff'] = reg_coeff
                group['weight_decay'] = None if fused_into_op(decay) else decay
                prepared_groups.append(group)
            parameters = prepared_groups

        py_regular = None if fused_into_op(weight_decay) else weight_decay
        super(Momentum, self).__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=py_regular,
            grad_clip=grad_clip,
            name=name)
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)
        self._regularization_method, self._regularization_coeff = self._update_regularization(
            weight_decay)
        self._multi_precision = multi_precision
        self._rescale_grad = rescale_grad
        self._master_weights = {}

        # Optimizer-level values that _update_param_group falls back to when a
        # parameter group does not override them. 'use_nesterov' stores the
        # normalized bool so the fallback matches self._use_nesterov.
        self._default_dict = {
            'momentum': momentum,
            'use_nesterov': self._use_nesterov,
            'rescale_grad': rescale_grad,
            'regularization_method': self._regularization_method,
            'regularization_coeff': self._regularization_coeff,
        }

        if framework.in_dygraph_mode():
            self.helper = LayerHelper(self.__class__.__name__)
            # In dygraph mode the accumulators are created eagerly here;
            # _create_accumulators returns immediately in this mode.
            if self._parameter_list and isinstance(self._parameter_list[0],
                                                   dict):
                for group in self._param_groups:
                    for p in group['params']:
                        self._add_accumulator(self._velocity_acc_str, p)
            else:
                for p in parameters:
                    self._add_accumulator(self._velocity_acc_str, p)

    def _update_regularization(self, weight_decay):
        """Map a weight-decay setting to the momentum op's regularization attributes.

        Args:
            weight_decay: a float coefficient, an ``L2DecayRegularizer``, or
                any other value (e.g. None or another regularizer).

        Returns:
            tuple: ``("l2_decay", coeff)`` for a float or an
            ``L2DecayRegularizer``; ``("", 0)`` for anything else.
        """
        if isinstance(weight_decay, float):
            return "l2_decay", weight_decay
        if isinstance(weight_decay, L2DecayRegularizer):
            return "l2_decay", weight_decay._regularization_coeff
        # Nothing that can be fused into the op.
        return "", 0
J
Jiawei Wang 已提交
195

H
huangxu96 已提交
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
    def _create_master_weight(self, param):
        """Create and register an FP32 master copy of ``param``.

        A persistable float32 global variable with the shape of ``param`` is
        created, a ``cast`` op from ``param`` into it is appended to the
        startup program's global block, and the variable is stored in
        ``self._master_weights`` under ``param.name``.

        Args:
            param: the parameter to mirror.

        Returns:
            The newly created master-weight variable.
        """
        assert isinstance(self.helper, LayerHelper)

        master_name = unique_name.generate(param.name + "_fp32_master")
        master_var = layers.create_global_var(
            name=master_name,
            shape=param.shape,
            value=0,
            dtype='float32',
            persistable=True)

        startup_block = self.helper.startup_program.global_block()
        cast_attrs = {
            "in_dtype": param.dtype,
            "out_dtype": core.VarDesc.VarType.FP32
        }
        startup_block.append_op(
            type="cast",
            inputs={"X": [param]},
            outputs={"Out": [master_var]},
            attrs=cast_attrs)

        self._master_weights[param.name] = master_var
        return master_var

    def _get_accumulator(self, name, param):
        """Look up the accumulator registered for a parameter.

        When multi-precision is enabled and ``param`` is FP16, the accumulator
        is looked up under the name of the parameter's master weight instead
        of the parameter itself.

        Args:
            name: name of the accumulator; prefixed with ``self._name`` and an
                underscore when the optimizer has a name.
            param: parameter variable whose accumulator is requested.

        Returns:
            The accumulator variable for the parameter.

        Raises:
            Exception: if no such accumulator has been registered.
        """
        acc_key = name if self._name is None else self._name + "_" + name

        use_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        owner = self._master_weights[param.name] if use_master else param
        owner_name = owner.name

        if (acc_key in self._accumulators and
                owner_name in self._accumulators[acc_key]):
            return self._accumulators[acc_key][owner_name]

        raise Exception("Accumulator {} does not exist for parameter {}".
                        format(acc_key, owner_name))

J
Jiawei Wang 已提交
241
    def _create_accumulators(self, block, parameters):
        """Create the velocity accumulator for every parameter (static graph only).

        In dygraph mode this returns immediately; the accumulators are created
        in ``__init__`` instead. For an FP16 parameter with multi-precision
        enabled, an FP32 master weight is created and the accumulator is
        attached to that master weight rather than to the parameter.

        Args:
            block: the ``framework.Block`` being optimized.
            parameters: iterable of parameters to create accumulators for.
        """
        if framework.in_dygraph_mode():
            return

        assert isinstance(block, framework.Block)
        for p in parameters:
            is_fp16 = p.dtype == core.VarDesc.VarType.FP16
            if is_fp16 and self._multi_precision:
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._velocity_acc_str, master_p)
                continue
            if is_fp16:
                # Trailing space after the first sentence: the two literals
                # are concatenated into one message.
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Momentum optimizer."
                )
            self._add_accumulator(self._velocity_acc_str, p)
J
Jiawei Wang 已提交
257

258 259 260 261 262 263 264 265 266 267 268 269 270
    def _create_regularization_of_grad(self, param, grad, regularization=None):
        """Create the backward regularization for ``grad`` unless it is fused.

        Helper used by append_regularization_ops. When the parameter carries
        its own ``L2DecayRegularizer``, the gradient is returned untouched
        because that decay is fused into the momentum op in
        ``_append_optimize_op``. Every other case is delegated to the base
        class implementation.
        """
        own_regularizer = getattr(param, 'regularizer', None)
        if isinstance(own_regularizer, L2DecayRegularizer):
            # Applied later inside the momentum op; skip it here.
            return grad
        return super(Momentum, self)._create_regularization_of_grad(
            param, grad, regularization)

J
Jiawei Wang 已提交
271 272
    def _append_optimize_op(self, block, param_and_grad):
        """Apply (dygraph) or append (static graph) the momentum update for one parameter.

        Args:
            block: the ``framework.Block`` the op belongs to.
            param_and_grad: either an indexable ``(param, grad)`` pair, or a
                parameter-group dict whose ``'params'`` entry is such a pair;
                for a dict, the group's hyper-parameters are first loaded onto
                ``self`` via ``_update_param_group``.

        Returns:
            None in dygraph mode; otherwise the op returned by
            ``block.append_op``.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            # From here on param_and_grad is the group's 'params' value.
            param_and_grad = self._update_param_group(param_and_grad)

        velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                             param_and_grad[0])
        lr = self._create_param_lr(param_and_grad)

        # For fusion of momentum and l2decay
        param = param_and_grad[0]
        regularization_method = self._regularization_method
        regularization_coeff = self._regularization_coeff
        if hasattr(param, 'regularizer'):
            # we skipped the param's own l2decay earlier, so fuse it with
            # momentum here.
            if isinstance(param.regularizer, L2DecayRegularizer):
                regularization_method = "l2_decay"
                regularization_coeff = param.regularizer._regularization_coeff
            # any other regularizer on the param has already been applied, so
            # no l2decay must be done inside the momentum op.
            elif param.regularizer is not None:
                regularization_method = ""
                regularization_coeff = 0

        if framework.in_dygraph_mode():
            # NOTE(review): unlike the static-graph attrs below, this call does
            # not forward 'rescale_grad' or 'multi_precision' - confirm that
            # this is intended.
            _C_ops.momentum(
                param_and_grad[0], param_and_grad[1], velocity_acc, lr,
                param_and_grad[0], velocity_acc, 'mu', self._momentum,
                'use_nesterov', self._use_nesterov, 'regularization_method',
                regularization_method, 'regularization_coeff',
                regularization_coeff)
            return None

        # Coerce to a real bool so the op attribute never receives a falsy
        # non-bool value of self._multi_precision.
        find_master = bool(self._multi_precision and
                           param.dtype == core.VarDesc.VarType.FP16)
        master_weight = (self._master_weights[param.name]
                         if find_master else None)

        attrs = {
            "mu": self._momentum,
            "use_nesterov": self._use_nesterov,
            "regularization_method": regularization_method,
            "regularization_coeff": regularization_coeff,
            "multi_precision": find_master,
            "rescale_grad": self._rescale_grad
        }

        inputs = {
            "Param": [param],
            "Grad": [param_and_grad[1]],
            "Velocity": [velocity_acc],
            "LearningRate": [lr]
        }

        # Param and velocity are updated in place: they are both inputs and
        # outputs of the op.
        outputs = {"ParamOut": [param], "VelocityOut": [velocity_acc]}

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True)

        return momentum_op
344 345 346 347 348 349 350 351 352 353 354 355 356 357 358

    def _update_param_group(self, parameters):
        """Load a parameter group's hyper-parameters onto the optimizer.

        For each supported key, the group's value is used when present and
        the optimizer-level default from ``self._default_dict`` otherwise.
        The value is stored on ``self`` under the key prefixed with an
        underscore (e.g. ``'momentum'`` -> ``self._momentum``).

        Args:
            parameters: the parameter-group dict.

        Returns:
            The group's ``'params'`` entry (None if the key is absent).
        """
        defaults = self._default_dict
        overridable_keys = (
            'momentum',
            'use_nesterov',
            'rescale_grad',
            'regularization_method',
            'regularization_coeff', )
        for key in overridable_keys:
            setattr(self, '_' + key, parameters.get(key, defaults[key]))
        return parameters.get('params')