# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import paddle
from paddle import _C_ops

from ..fluid import core, framework, unique_name
from ..fluid.dygraph import no_grad
from ..fluid.layer_helper import LayerHelper
from ..framework import in_dygraph_mode
from .optimizer import Optimizer

__all__ = []


class Adadelta(Optimizer):
    r"""
    **Notes: This API does not support sparse parameter optimization.**

    Adadelta Optimizer. Please refer to this for details:
    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD <https://arxiv.org/abs/1212.5701>`_.

    The update is done as follows:

    .. math::

        E(g_t^2) &= \rho * E(g_{t-1}^2) + (1-\rho) * g_t^2

        learning\_rate &= \sqrt{ ( E(dx_{t-1}^2) + \epsilon ) / ( E(g_t^2) + \epsilon ) }

        E(dx_t^2) &= \rho * E(dx_{t-1}^2) + (1-\rho) * (-g_t*learning\_rate)^2
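
    A minimal NumPy sketch of a single Adadelta update that mirrors the formulas above
    (illustrative only; names such as ``avg_sq_grad`` and ``avg_sq_update`` are not part
    of the Paddle API):

    .. code-block:: python

        import numpy as np

        rho, epsilon = 0.95, 1.0e-6
        param = np.random.rand(10).astype("float32")
        grad = np.random.rand(10).astype("float32")
        avg_sq_grad = np.zeros_like(param)    # E(g^2): running average of squared gradients
        avg_sq_update = np.zeros_like(param)  # E(dx^2): running average of squared updates

        # One Adadelta step, following the update rule above.
        avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad ** 2
        lr = np.sqrt((avg_sq_update + epsilon) / (avg_sq_grad + epsilon))
        update = -lr * grad
        avg_sq_update = rho * avg_sq_update + (1 - rho) * update ** 2
        param += update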

    Args:
        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        epsilon (float): a small float number for numeric stability. Default 1.0e-6.
        rho (float): a floating point value indicating the decay rate. Default 0.95.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. You can also specify different options \
            (such as the learning rate or weight decay) for different parameter groups by \
            passing a list of dicts. Note that the learning_rate in a parameter group \
            is a scale factor applied to the base learning_rate. \
            The default value is None in static graph mode, in which case all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value used as the coefficient of L2 regularization, or a \
            regularizer such as :ref:`api_fluid_regularizer_L1Decay` or :ref:`api_fluid_regularizer_L2Decay`. \
            If a parameter has already set a regularizer via :ref:`api_fluid_ParamAttr`, \
            the regularization setting here in the optimizer is ignored for that parameter; \
            otherwise, the setting here in the optimizer takes effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, an instance of
            some derived class of ``GradientClipBase``. There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): The default value is None. Normally there is no need for the user
            to set this property. For more information, please refer to
            :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            out.backward()
            adadelta.step()
            adadelta.clear_grad()

            # Note that the effective learning_rate of linear_2's parameters is 0.1 * 0.1 = 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adadelta = paddle.optimizer.Adadelta(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                }],
                weight_decay=0.01)
            out.backward()
            adadelta.step()
            adadelta.clear_grad()

    """

    _avg_squared_grad_acc_str = "_avg_squared_grad"
    _avg_squared_update_acc_str = "_avg_squared_update"

    def __init__(
        self,
        learning_rate=0.001,
        epsilon=1.0e-6,
        rho=0.95,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self._multi_precision = False
        self._master_weights = {}
        self.type = "adadelta"
        self._epsilon = epsilon
        self._rho = rho
        self._default_dict = {
            'epsilon': epsilon,
            'rho': rho,
        }

    def _create_master_weight(self, param):
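        """Create (or return a cached) FP32 master copy of ``param``.

        Used for multi-precision updates: when ``param`` is FP16, the optimizer keeps a
        persistable FP32 master weight and casts the parameter into it once at startup,
        so updates are accumulated in FP32 instead of FP16.
        """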
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter
        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _create_accumulators(self, block, parameters):
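        """Create the Adadelta accumulators (running averages of the squared gradient
        and of the squared update) for each parameter, attaching them to the FP32
        master weight instead of the raw FP16 parameter when multi-precision is enabled.
        """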
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")
        if isinstance(parameters, dict):
            parameters = parameters.get('params')

        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._avg_squared_grad_acc_str, master_p)
                self._add_accumulator(
                    self._avg_squared_update_acc_str, master_p
                )
                continue
            if (
                p.dtype == core.VarDesc.VarType.FP16
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
                    "Consider using multi_precision=True option of the Lars optimizer."
                )
            self._add_accumulator(self._avg_squared_grad_acc_str, p)
            self._add_accumulator(self._avg_squared_update_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
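        """Apply one Adadelta update: in dygraph mode via the in-place
        ``_C_ops.adadelta_`` kernel, in static graph mode by appending an
        ``adadelta`` op to ``block``.
        """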
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        avg_squared_grad_acc = self._get_accumulator(
            self._avg_squared_grad_acc_str, param_and_grad[0]
        )
        avg_squared_update_acc = self._get_accumulator(
            self._avg_squared_update_acc_str, param_and_grad[0]
        )
        find_master = (
            self._multi_precision
            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
        )
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )

        if in_dygraph_mode():
            with no_grad():
                _C_ops.adadelta_(
                    param_and_grad[0],
                    param_and_grad[1],
                    avg_squared_grad_acc,
                    avg_squared_update_acc,
                    master_weight,
                    self._rho,
                    self._epsilon,
                    find_master,
                )
            return None
        else:
            if not isinstance(block, framework.Block):
                raise TypeError("block is not instance of framework.Block.")

            # Create the adadelta optimizer op
            inputs = {
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "AvgSquaredGrad": avg_squared_grad_acc,
                "AvgSquaredUpdate": avg_squared_update_acc,
            }
            outputs = {
                "ParamOut": param_and_grad[0],
                "AvgSquaredGradOut": avg_squared_grad_acc,
                "AvgSquaredUpdateOut": avg_squared_update_acc,
            }
            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight
            adadelta_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs={
                    "epsilon": self._epsilon,
                    "rho": self._rho,
                    "multi_precision": find_master,
                },
                stop_gradient=True,
            )

            return adadelta_op

    def _update_param_group(self, parameters):
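        """Pick up per-group overrides of ``epsilon`` and ``rho`` (falling back to the
        constructor defaults) and return the group's parameter list.
        """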
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._rho = parameters.get('rho', self._default_dict['rho'])
        parameters = parameters.get('params')
        return parameters