adagrad.py 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import warnings
15

16
from ..fluid import framework
17
from .optimizer import Optimizer
18

19 20
__all__ = []

21 22

class Adagrad(Optimizer):
    r"""
    The Adaptive Gradient optimizer (Adagrad for short) uses an optimization described
    in paper: `Adaptive Subgradient Methods for Online Learning and
    Stochastic Optimization <http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf>`_.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        moment\_out &= moment + grad * grad

        param\_out &= param - \frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}


    The original paper does not have the ``epsilon`` attribute. It is added here
    in our implementation as also proposed in `Per-parameter adaptive learning rate
    methods <http://cs231n.github.io/neural-networks-3/#ada>`_
    for numerical stability to avoid the division by zero error.

    Args:
        learning_rate (float|Tensor): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Variable`` with a float type.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-06.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. And you can specify different options for
            different parameter groups such as the learning rate, weight decay, etc,
            then the parameters are list of dict. Note that the learning_rate in parameter groups
            represents the scale of base learning_rate.
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_paddle_fluid_param_attr_aramAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies,
            ClipGradByGlobalNorm, ClipGradByNorm and ClipGradByValue. Default None,
            meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
        initial_accumulator_value (float, optional): Initial value for moment accumulator.
            The default value is 0.0.

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.rand(shape=[10, 10])
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            adagrad = paddle.optimizer.Adagrad(learning_rate=0.1,
                    parameters=linear.parameters())
            out.backward()
            adagrad.step()
            adagrad.clear_grad()

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adagrad = paddle.optimizer.Adagrad(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                }],
                weight_decay=0.01)
            out.backward()
            adagrad.step()
            adagrad.clear_grad()

    """
    # Name under which the per-parameter moment accumulator is registered.
    _moment_acc_str = "moment"

    def __init__(
        self,
        learning_rate,
        epsilon=1.0e-6,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
        initial_accumulator_value=0.0,
    ):
        """Store the Adagrad hyper-parameters and initialise the base optimizer.

        ``learning_rate`` and ``epsilon`` must not be None; the remaining
        arguments are forwarded to ``Optimizer.__init__`` or kept as
        per-instance defaults (see the class docstring for their meaning).
        """
        assert learning_rate is not None
        assert epsilon is not None
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "adagrad"
        self._epsilon = epsilon
        # Multi-precision is disabled here; the constructor exposes no
        # argument to turn it on.
        self._multi_precision = False
        self._master_weights = {}
        self.initial_accumulator_value = initial_accumulator_value
        # Fallback values used by _update_param_group when a parameter group
        # does not override them.
        self._default_dict = {
            'epsilon': epsilon,
            'initial_accumulator_value': initial_accumulator_value,
        }

    def _create_accumulators(self, block, parameters):
        """Register a moment accumulator for every parameter.

        Args:
            block: The ``framework.Block`` the optimizer works on.
            parameters: Either an iterable of parameters, or a parameter-group
                dict, in which case the group's options are applied first via
                ``_update_param_group``.
        """
        assert isinstance(block, framework.Block)

        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                # The accumulator is attached to the master weight. It must
                # start from the same initial value as the regular path,
                # otherwise ``initial_accumulator_value`` is silently ignored.
                master_p = self._create_master_weight(p)
                self._add_accumulator(
                    self._moment_acc_str,
                    master_p,
                    fill_value=self.initial_accumulator_value,
                )
                continue
            if (
                self._is_dtype_fp16_or_bf16(p.dtype)
                and not self._multi_precision
            ):
                # NOTE(review): ``_multi_precision`` is hard-coded to False in
                # __init__, so the option mentioned below has to be enabled
                # by other means -- confirm the intended user-facing switch.
                warnings.warn(
                    "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adagrad optimizer."
                )
            self._add_accumulator(
                self._moment_acc_str,
                p,
                fill_value=self.initial_accumulator_value,
            )

    def _append_optimize_op(self, block, param_and_grad):
        """Append the ``adagrad`` update op for one (parameter, gradient) pair.

        Args:
            block: The ``framework.Block`` that receives the op.
            param_and_grad: A ``(param, grad)`` pair, or a parameter-group dict
                that is resolved through ``_update_param_group`` first.

        Returns:
            The op object returned by ``block.append_op``.
        """
        assert isinstance(block, framework.Block)

        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        moment_acc = self._get_accumulator_master(
            self._moment_acc_str, param_and_grad[0]
        )

        # A master weight is only used for FP16/BF16 parameters when
        # multi-precision is enabled.
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param_and_grad[0].dtype
        )

        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )

        # Create the adagrad optimizer op
        inputs = {
            "Param": param_and_grad[0],
            "Grad": param_and_grad[1],
            "Moment": moment_acc,
            "LearningRate": self._create_param_lr(param_and_grad),
        }

        # Parameter and moment are listed as both input and output of the op.
        outputs = {"ParamOut": param_and_grad[0], "MomentOut": moment_acc}

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adagrad_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs={"epsilon": self._epsilon, "multi_precision": find_master},
            stop_gradient=True,
        )

        return adagrad_op

    def _update_param_group(self, parameters):
        """Apply a parameter group's options and return its parameters.

        ``epsilon`` and ``initial_accumulator_value`` are read from the group
        dict, falling back to the values given at construction time, and are
        stored on the optimizer instance.

        Args:
            parameters: A parameter-group dict.

        Returns:
            The value stored under the group's ``'params'`` key.
        """
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self.initial_accumulator_value = parameters.get(
            'initial_accumulator_value',
            self._default_dict['initial_accumulator_value'],
        )
        parameters = parameters.get('params')
        return parameters