# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..fluid import framework
from .optimizer import Optimizer

# Nothing is exported via ``from ... import *``; the class is exposed
# explicitly by the package.
__all__ = []


class Adagrad(Optimizer):
    r"""
    The Adaptive Gradient optimizer (Adagrad for short) uses the optimization
    method described in the paper: `Adaptive Subgradient Methods for Online Learning and
    Stochastic Optimization <http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf>`_.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        moment\_out &= moment + grad * grad

        param\_out &= param - \frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}


    The original paper does not have the ``epsilon`` attribute. It is added here
    in our implementation, as also proposed in `Per-parameter adaptive learning rate
    methods <http://cs231n.github.io/neural-networks-3/#ada>`_,
    for numerical stability to avoid the division by zero error.

    Args:
        learning_rate (float|Tensor): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Tensor`` with a float type.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-06.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate, weight decay, etc.;
            in that case the parameters are a list of dicts. Note that the learning_rate in parameter
            groups represents the scale of the base learning_rate. A group dict may also override
            ``epsilon`` and ``initial_accumulator_value``; groups that do not fall back to the
            values given to the constructor.
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_paddle_fluid_param_attr_aramAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies,
            ClipGradByGlobalNorm, ClipGradByNorm and ClipGradByValue. Default None,
            meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
        initial_accumulator_value (float, optional): Initial value for moment accumulator.
            The default value is 0.0.

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.rand(shape=[10, 10])
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            adagrad = paddle.optimizer.Adagrad(learning_rate=0.1,
                    parameters=linear.parameters())
            loss.backward()
            adagrad.step()
            adagrad.clear_grad()

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adagrad = paddle.optimizer.Adagrad(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                }],
                weight_decay=0.01)
            loss.backward()
            adagrad.step()
            adagrad.clear_grad()

    """
    # Name under which the per-parameter squared-gradient accumulator is
    # registered and later looked up.
    _moment_acc_str = "moment"

    def __init__(
        self,
        learning_rate,
        epsilon=1.0e-6,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
        initial_accumulator_value=0.0,
    ):
        assert learning_rate is not None
        assert epsilon is not None
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        # Operator type appended to the program by _append_optimize_op.
        self.type = "adagrad"
        self._epsilon = epsilon
        self.initial_accumulator_value = initial_accumulator_value
        # Constructor-level values that _update_param_group falls back to when
        # a parameter-group dict does not override them.
        self._default_dict = {
            'epsilon': epsilon,
            'initial_accumulator_value': initial_accumulator_value,
        }

    def _create_accumulators(self, block, parameters):
        """Register one ``moment`` accumulator for every parameter.

        Args:
            block (framework.Block): The block the accumulators belong to.
            parameters: An iterable of parameters, or a parameter-group dict
                whose ``'params'`` entry holds them. For a dict, the group's
                ``initial_accumulator_value`` override, if any, is applied
                before the accumulators are filled.
        """
        assert isinstance(block, framework.Block)

        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            self._add_accumulator(
                self._moment_acc_str,
                p,
                fill_value=self.initial_accumulator_value,
            )

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``adagrad`` op that updates one parameter in place.

        Args:
            block (framework.Block): The block to append the op to.
            param_and_grad: A ``(param, grad)`` pair, or a dict whose
                ``'params'`` entry holds that pair plus optional group-level
                ``epsilon`` / ``initial_accumulator_value`` overrides.

        Returns:
            The appended operator.
        """
        assert isinstance(block, framework.Block)

        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        moment_acc = self._get_accumulator(
            self._moment_acc_str, param_and_grad[0]
        )
        # Create the adagrad optimizer op. ParamOut/MomentOut alias the input
        # variables, so the parameter and its accumulator are updated in place.
        adagrad_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": moment_acc,
                "LearningRate": self._create_param_lr(param_and_grad),
            },
            outputs={"ParamOut": param_and_grad[0], "MomentOut": moment_acc},
            attrs={"epsilon": self._epsilon},
            stop_gradient=True,
        )

        return adagrad_op

    def _update_param_group(self, parameters):
        """Apply a parameter group's overrides and return its ``'params'``.

        ``epsilon`` and ``initial_accumulator_value`` are taken from the group
        dict when present, otherwise from the constructor defaults. The values
        are stored on ``self``, so they stay in effect until the next call.
        """
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self.initial_accumulator_value = parameters.get(
            'initial_accumulator_value',
            self._default_dict['initial_accumulator_value'],
        )
        parameters = parameters.get('params')
        return parameters