adadelta.py 9.5 KB
Newer Older
J
Jiawei Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import warnings

17
from paddle import _C_ops
18

19
from ..fluid import framework
20
from ..fluid.dygraph import no_grad
21 22
from ..framework import in_dygraph_mode
from .optimizer import Optimizer
J
Jiawei Wang 已提交
23

24 25
__all__ = []

J
Jiawei Wang 已提交
26 27

class Adadelta(Optimizer):
    r"""
    **Notes: This API does not support sparse parameter optimization.**

    Adadelta Optimizer. Please refer to this for details:
    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD <https://arxiv.org/abs/1212.5701>`_.

    The update is done as follows:

    .. math::

        E(g_t^2) &= \rho * E(g_{t-1}^2) + (1-\rho) * g^2

        learning\_rate &= \sqrt{ ( E(dx_{t-1}^2) + \epsilon ) / ( E(g_t^2) + \epsilon ) }

        E(dx_t^2) &= \rho * E(dx_{t-1}^2) + (1-\rho) * (-g*learning\_rate)^2

    Args:
        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        epsilon (float): a small float number for numeric stability. Default 1.0e-6.
        rho (float): a floating point value indicating the decay rate. Default 0.95.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in parameter groups \
            represents the scale of base learning_rate. \
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): The default value is None. Normally there is no need for user
            to set this property. For more information, please refer to
            :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            loss.backward()
            adadelta.step()
            adadelta.clear_grad()

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adadelta = paddle.optimizer.Adadelta(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                }],
                weight_decay=0.01)
            loss.backward()
            adadelta.step()
            adadelta.clear_grad()

    """

    # Accumulator names registered per parameter via ``_add_accumulator``.
    _avg_squared_grad_acc_str = "_avg_squared_grad"
    _avg_squared_update_acc_str = "_avg_squared_update"

    def __init__(
        self,
        learning_rate=0.001,
        epsilon=1.0e-6,
        rho=0.95,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
        # Explicit ``None`` is rejected up front; other values are passed on as-is.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        # NOTE(review): multi-precision (FP32 master weights) is hard-coded
        # off here; this constructor exposes no option to enable it, even
        # though the accumulator/op code below handles the enabled case.
        self._multi_precision = False
        self._master_weights = {}
        # Op type name used when appending the op in static graph mode.
        self.type = "adadelta"
        self._epsilon = epsilon
        self._rho = rho
        # Fallbacks used by ``_update_param_group`` when a parameter group
        # does not override these hyper-parameters.
        self._default_dict = {
            'epsilon': epsilon,
            'rho': rho,
        }

    def _create_accumulators(self, block, parameters):
        """Register the two running-average accumulators for every parameter.

        Args:
            block: must be a ``framework.Block``; anything else raises.
            parameters: an iterable of parameters, or a parameter-group dict
                whose ``'params'`` entry holds them.

        Raises:
            TypeError: if ``block`` is not a ``framework.Block``.
        """
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")
        if isinstance(parameters, dict):
            parameters = parameters.get('params')

        for p in parameters:
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                # Low-precision param with multi-precision: keep the
                # accumulators on the master weight instead of ``p``.
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._avg_squared_grad_acc_str, master_p)
                self._add_accumulator(
                    self._avg_squared_update_acc_str, master_p
                )
                continue
            if (
                self._is_dtype_fp16_or_bf16(p.dtype)
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adadelta optimizer."
                )
            self._add_accumulator(self._avg_squared_grad_acc_str, p)
            self._add_accumulator(self._avg_squared_update_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Apply (dygraph) or append (static graph) one Adadelta update.

        Args:
            block: target ``framework.Block``; only checked in static mode.
            param_and_grad: a ``(param, grad)`` pair, or a parameter-group
                dict, which is first resolved via ``_update_param_group``.

        Returns:
            ``None`` in dygraph mode (the update runs in place); otherwise the
            op appended to ``block``.

        Raises:
            TypeError: in static mode, if ``block`` is not a ``framework.Block``.
        """
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        param, grad = param_and_grad[0], param_and_grad[1]

        avg_squared_grad_acc = self._get_accumulator_master(
            self._avg_squared_grad_acc_str, param
        )
        avg_squared_update_acc = self._get_accumulator_master(
            self._avg_squared_update_acc_str, param
        )
        # Use the FP32 master copy only for low-precision params when
        # multi-precision is enabled.
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        master_weight = (
            self._master_weights[param.name] if find_master else None
        )

        # NOTE(review): no learning-rate tensor is passed to the adadelta op
        # in either branch below, so ``learning_rate`` (and per-group
        # learning_rate scales) does not appear to affect the update here —
        # confirm against the kernel definition.
        if in_dygraph_mode():
            with no_grad():
                _C_ops.adadelta_(
                    param,
                    grad,
                    avg_squared_grad_acc,
                    avg_squared_update_acc,
                    master_weight,
                    self._rho,
                    self._epsilon,
                    find_master,
                )
            return None

        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        # Create the adadelta optimizer op; outputs alias the inputs so the
        # param and accumulators are updated in place.
        inputs = {
            "Param": param,
            "Grad": grad,
            "AvgSquaredGrad": avg_squared_grad_acc,
            "AvgSquaredUpdate": avg_squared_update_acc,
        }
        outputs = {
            "ParamOut": param,
            "AvgSquaredGradOut": avg_squared_grad_acc,
            "AvgSquaredUpdateOut": avg_squared_update_acc,
        }
        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adadelta_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs={
                "epsilon": self._epsilon,
                "rho": self._rho,
                "multi_precision": find_master,
            },
            stop_gradient=True,
        )

        return adadelta_op

    def _update_param_group(self, parameters):
        """Load a parameter group's hyper-parameters and return its params.

        Sets ``self._epsilon`` / ``self._rho`` from the group's ``'epsilon'``
        / ``'rho'`` keys, falling back to the constructor values. This mutates
        the optimizer, so each group must be resolved before its update.

        Args:
            parameters (dict): a parameter-group dict.

        Returns:
            The group's ``'params'`` entry (``None`` if absent).
        """
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._rho = parameters.get('rho', self._default_dict['rho'])
        parameters = parameters.get('params')
        return parameters