# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from paddle import _C_ops

from ..fluid import core, framework
from ..fluid.dygraph import no_grad
from ..fluid.framework import name_scope
from .optimizer import Optimizer

__all__ = []


class Adamax(Optimizer):
    r"""
    The Adamax optimizer is implemented based on the Adamax Optimization
    in Section 7 of `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    The Adamax algorithm is a variant of the Adam algorithm based on the infinity norm,
    which makes the learning rate update rule simpler and more stable.

    The update rule for parameter ``param_out`` with gradient ``grad`` is:

    .. math::

        t & = t + 1

        moment\_out & = {\beta}_1 * moment + (1 - {\beta}_1) * grad

        inf\_norm\_out & = max({\beta}_2 * inf\_norm + \epsilon, |grad|)

        learning\_rate & = \frac{learning\_rate}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_out}{inf\_norm\_out}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    The original paper does not have an ``epsilon`` attribute;
    it is added here for numerical stability to prevent division-by-zero errors.

    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            The default value is 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. You can also specify different options for
            different parameter groups, such as the learning rate and weight decay; in that case,
            the parameters are a list of dicts. Note that the learning_rate in parameter groups
            represents a scale applied to the base learning_rate.
            The default value is None in static graph mode, in which case all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as the coefficient of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has already set a regularizer using :ref:`api_fluid_ParamAttr`,
            the regularization setting here in the optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in the optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, Adamax doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.Adamax(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            out.backward()
            adam.step()
            adam.clear_grad()


            # Note that the effective learning_rate of linear_2 is 0.1 * 0.1 = 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adamax(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            out.backward()
            adam.step()
            adam.clear_grad()
    """
    _moment_acc_str = "moment"
    _inf_norm_acc_str = "inf_norm"
    _beta1_pow_acc_str = "beta1_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        if not 0 <= beta1 < 1:
            raise ValueError("Invalid value of beta1, expect beta1 in [0,1).")
        if not 0 <= beta2 < 1:
            raise ValueError("Invalid value of beta2, expect beta2 in [0,1).")
        if not 0 <= epsilon:
            raise ValueError("Invalid value of epsilon, expect epsilon >= 0.")
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "adamax"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._multi_precision = False
        self._master_weights = {}

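        # Per-group overrides fall back to these defaults when switching
        # parameter groups (see _update_param_group).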
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
        }

    def _add_moments_pows(self, p):
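        # Allocate the per-parameter Adamax state: a 1st-moment estimate, an
        # infinity-norm estimate, and a one-element beta1^t power accumulator.
        # For FP16/BF16 parameters the moment tensors are accumulated in FP32.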
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32

        self._add_accumulator(self._moment_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._inf_norm_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            fill_value=self._beta1,
            shape=[1],
        )

    def _create_accumulators(self, block, parameters):
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first moment and infinity norm
        for p in parameters:
            if p.name in self._already_create_accumulater:
                continue
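            # With multi_precision enabled, the accumulators are attached to the
            # FP32 master weight instead of the FP16/BF16 parameter itself.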
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                self._already_create_accumulater.add(p.name)
                continue
            if (
                self._is_dtype_fp16_or_bf16(p.dtype)
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using the multi_precision=True option of the Adamax optimizer."
                )
            self._add_moments_pows(p)
            self._already_create_accumulater.add(p.name)

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        moment = self._get_accumulator_master(
            self._moment_acc_str, param_and_grad[0]
        )
        inf_norm = self._get_accumulator_master(
            self._inf_norm_acc_str, param_and_grad[0]
        )

        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param_and_grad[0].dtype
        )
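        # When a master weight exists for an FP16/BF16 parameter, hand the FP32
        # copy to the kernel as well so it can be updated alongside the parameter.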
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )

        beta1_pow_acc = self._get_accumulator_master(
            self._beta1_pow_acc_str, param_and_grad[0]
        )
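        # In dygraph mode call the fused in-place C++ adamax kernel directly;
        # in static-graph mode append an `adamax` operator to the block instead.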
        if framework.in_dygraph_mode():
            _C_ops.adamax_(
                param_and_grad[0],
                param_and_grad[1],
                self._create_param_lr(param_and_grad),
                moment,
                inf_norm,
                beta1_pow_acc,
                master_weight,
                self._beta1,
                self._beta2,
                self._epsilon,
                find_master,
            )

        else:
            # create the adamax optimize op
            inputs = {
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad),
                "Moment": moment,
                "InfNorm": inf_norm,
                "Beta1Pow": beta1_pow_acc,
            }
            outputs = {
                "ParamOut": param_and_grad[0],
                "MomentOut": moment,
                "InfNormOut": inf_norm,
            }
            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight
            attrs = {
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon,
                "multi_precision": find_master,
            }
            adamax_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )

            return adamax_op

    def _finish_update(self, block, parameters_and_grads):
        """Update Beta1 Power accumulator"""
        assert isinstance(block, framework.Block)
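        # The adamax op only writes ParamOut/MomentOut/InfNormOut, so the beta1^t
        # accumulator is scaled by beta1 here once per step; it drives the
        # learning-rate bias correction learning_rate / (1 - beta1^t).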
        if isinstance(parameters_and_grads, list):
            for param, grad in parameters_and_grads:
                if grad is None or param.stop_gradient is True:
                    continue
                if framework.in_dygraph_mode():
                    beta1_pow_acc = self._get_accumulator_master(
                        self._beta1_pow_acc_str, param
                    )
                    with no_grad():
                        tmp = _C_ops.scale(
                            beta1_pow_acc, self._beta1, 0.0, True
                        )
                        beta1_pow_acc.copy_(tmp, False)
                else:
                    with param.block.program._optimized_guard(
                        [param, grad]
                    ), name_scope('adamax'):
                        beta1_pow_acc = self._get_accumulator_master(
                            self._beta1_pow_acc_str, param
                        )
                        block.append_op(
                            type="scale",
                            inputs={"X": beta1_pow_acc},
                            outputs={"Out": beta1_pow_acc},
                            attrs={"scale": self._beta1},
                            stop_gradient=True,
                        )
        else:
            for param, grad in parameters_and_grads['params']:
                if grad is None or param.stop_gradient is True:
                    continue
                if framework.in_dygraph_mode():
                    beta1_pow_acc = self._get_accumulator_master(
                        self._beta1_pow_acc_str, param
                    )
                    self._beta1 = parameters_and_grads.get(
                        'beta1', self._default_dict['beta1']
                    )
                    with no_grad():
                        tmp = _C_ops.scale(
                            beta1_pow_acc, self._beta1, 0.0, True
                        )
                        beta1_pow_acc.copy_(tmp, False)
                else:
                    with param.block.program._optimized_guard(
                        [param, grad]
                    ), name_scope('adamax'):
                        beta1_pow_acc = self._get_accumulator_master(
                            self._beta1_pow_acc_str, param
                        )
                        self._beta1 = parameters_and_grads.get(
                            'beta1', self._default_dict['beta1']
                        )
                        block.append_op(
                            type="scale",
                            inputs={"X": beta1_pow_acc},
                            outputs={"Out": beta1_pow_acc},
                            attrs={"scale": self._beta1},
                            stop_gradient=True,
                        )

    def _update_param_group(self, parameters):
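        # Apply this group's hyper-parameter overrides (falling back to the
        # defaults captured in __init__) and return the group's parameter list.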
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        parameters = parameters.get('params')
        return parameters