adamax.py 12.8 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from ..fluid import framework
17
from ..fluid.framework import name_scope
18
from paddle import _C_ops, _legacy_C_ops
19
from ..fluid.dygraph import no_grad
M
MRXLT 已提交
20

21 22
__all__ = []

M
MRXLT 已提交
23 24

class Adamax(Optimizer):
25
    r"""
26
    The Adamax optimizer is implemented based on the Adamax Optimization
M
MRXLT 已提交
27 28 29 30 31 32 33 34 35 36
    in Section 7 of `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    The Adamax algorithm is a variant of the Adam algorithm based on the infinity norm,
    which makes the learning rate update algorithm more stable and simple.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

37
        moment\_out & = {\beta}_1 * moment + (1 - {\beta}_1) * grad
M
MRXLT 已提交
38

39
        inf\_norm\_out & = max({\beta}_2 * inf\_norm + \epsilon, |grad|)
M
MRXLT 已提交
40

41
        learning\_rate & = \frac{learning\_rate}{1 - {\beta}_1^t}
M
MRXLT 已提交
42

43
        param\_out & = param - learning\_rate * \frac{moment\_out}{inf\_norm\_out}
M
MRXLT 已提交
44 45 46 47 48 49 50

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    The original paper does not have an ``epsilon`` attribute,
    it is added here for numerical stability to prevent the division by 0 error.

    Args:
51 52
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
M
MRXLT 已提交
53 54 55 56 57 58
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            The default value is 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
59 60 61 62 63 64 65 66 67 68 69 70 71
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. And you can specify different options for
            different parameter groups such as the learning rate, weight decay, etc,
            then the parameters are a list of dicts. Note that the learning_rate in parameter groups
            represents the scale of base learning_rate.
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value used as the coefficient of L2 regularization, or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
72 73 74
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
M
MRXLT 已提交
75 76 77 78 79 80 81 82 83 84
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, Adamax doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python
85

M
MRXLT 已提交
86 87
            import paddle

88
            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
M
MRXLT 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.Adamax(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            out.backward()
            adam.step()
            adam.clear_grad()

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124

            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adamax(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
125
                beta1=0.9)
126 127 128
            out.backward()
            adam.step()
            adam.clear_grad()
M
MRXLT 已提交
129 130 131 132 133
    """
    _moment_acc_str = "moment"
    _inf_norm_acc_str = "inf_norm"
    _beta1_pow_acc_str = "beta1_pow_acc"

134 135 136 137 138 139 140 141 142 143 144
    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
M
MRXLT 已提交
145 146 147 148
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
M
MRXLT 已提交
149 150 151 152 153 154
        if not 0 <= beta1 < 1:
            raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
        if not 0 <= beta2 < 1:
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
155
        super().__init__(
156 157 158 159 160 161
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
M
MRXLT 已提交
162 163 164 165
        self.type = "adamax"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
166 167 168
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
169
            'epsilon': epsilon,
170
        }
M
MRXLT 已提交
171 172

    def _create_accumulators(self, block, parameters):
173 174 175
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

M
MRXLT 已提交
176 177 178 179
        # Create accumulator tensors for first moment and infinity norm
        for p in parameters:
            self._add_accumulator(self._moment_acc_str, p)
            self._add_accumulator(self._inf_norm_acc_str, p)
180 181 182 183 184 185
            self._add_accumulator(
                name=self._beta1_pow_acc_str,
                param=p,
                fill_value=self._beta1,
                shape=[1],
            )
M
MRXLT 已提交
186 187 188

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one Adamax update to a single (parameter, gradient) pair.

        Args:
            block: Must be a ``framework.Block``; the static-graph op is
                appended to it.
            param_and_grad: An indexable pair ``(param, grad)``, or a
                parameter-group dict that is resolved through
                ``_update_param_group`` first.

        Returns:
            The op returned by ``block.append_op`` in static-graph mode;
            ``None`` in the two dygraph modes.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        param = param_and_grad[0]
        moment, inf_norm, beta1_pow_acc = (
            self._get_accumulator(acc_name, param)
            for acc_name in (
                self._moment_acc_str,
                self._inf_norm_acc_str,
                self._beta1_pow_acc_str,
            )
        )
        grad = param_and_grad[1]

        if framework.in_dygraph_mode():
            _C_ops.adamax_(
                param,
                grad,
                self._create_param_lr(param_and_grad),
                moment,
                inf_norm,
                beta1_pow_acc,
                self._beta1,
                self._beta2,
                self._epsilon,
            )
            return None

        if framework._in_legacy_dygraph():
            # The legacy op takes inputs, then outputs, then attribute
            # name/value pairs; param, moment and inf_norm appear twice
            # because they are both read and written.
            _legacy_C_ops.adamax(
                param,
                grad,
                self._create_param_lr(param_and_grad),
                moment,
                inf_norm,
                beta1_pow_acc,
                param,
                moment,
                inf_norm,
                "beta1",
                self._beta1,
                "beta2",
                self._beta2,
                "epsilon",
                self._epsilon,
            )
            return None

        # Static graph: describe the adamax op and append it to the block.
        op_inputs = {
            "Param": param,
            "Grad": grad,
            "LearningRate": self._create_param_lr(param_and_grad),
            "Moment": moment,
            "InfNorm": inf_norm,
            "Beta1Pow": beta1_pow_acc,
        }
        op_outputs = {
            "ParamOut": param,
            "MomentOut": moment,
            "InfNormOut": inf_norm,
        }
        op_attrs = {
            "beta1": self._beta1,
            "beta2": self._beta2,
            "epsilon": self._epsilon,
        }
        return block.append_op(
            type=self.type,
            inputs=op_inputs,
            outputs=op_outputs,
            attrs=op_attrs,
            stop_gradient=True,
        )
M
MRXLT 已提交
256 257

    def _finish_update(self, block, parameters_and_grads):
        """Update Beta1 Power accumulator

        Multiplies the ``beta1_pow_acc`` accumulator of every trainable
        parameter that has a gradient by ``self._beta1``.

        Args:
            block: Must be a ``framework.Block``; in static-graph mode the
                ``scale`` ops are appended to it.
            parameters_and_grads: Either a list of ``(param, grad)`` pairs, or
                a parameter-group mapping whose ``'params'`` entry holds the
                pairs. For a mapping, ``self._beta1`` is refreshed from its
                ``'beta1'`` entry (falling back to the constructor default)
                each time a parameter is processed.
        """
        assert isinstance(block, framework.Block)

        if isinstance(parameters_and_grads, list):
            group = None
            pairs = parameters_and_grads
        else:
            group = parameters_and_grads
            pairs = parameters_and_grads['params']

        for param, grad in pairs:
            # Nothing to advance for frozen parameters or missing gradients.
            if grad is None or param.stop_gradient is True:
                continue

            if framework.in_dygraph_mode():
                beta1_pow_acc = self._get_accumulator(
                    self._beta1_pow_acc_str, param
                )
                if group is not None:
                    self._beta1 = group.get(
                        'beta1', self._default_dict['beta1']
                    )
                with no_grad():
                    scaled = _C_ops.scale(
                        beta1_pow_acc, self._beta1, 0.0, True
                    )
                    # Write the scaled value back into the accumulator.
                    beta1_pow_acc.copy_(scaled, False)
            else:
                with param.block.program._optimized_guard(
                    [param, grad]
                ), name_scope('adamax'):
                    beta1_pow_acc = self._get_accumulator(
                        self._beta1_pow_acc_str, param
                    )
                    if group is not None:
                        self._beta1 = group.get(
                            'beta1', self._default_dict['beta1']
                        )
                    # The scale op reads and writes the same accumulator.
                    block.append_op(
                        type="scale",
                        inputs={"X": beta1_pow_acc},
                        outputs={"Out": beta1_pow_acc},
                        attrs={"scale": self._beta1},
                        stop_gradient=True,
                    )
321 322 323 324 325 326 327

    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        parameters = parameters.get('params')
        return parameters