# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import paddle
from paddle import _C_ops

from ..fluid import core, framework, unique_name
from ..fluid.dygraph import no_grad
from ..fluid.framework import name_scope
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

# Nothing is exported by ``from <this module> import *``.
__all__ = []

class Adamax(Optimizer):
    r"""
    The Adamax optimizer is implemented based on the Adamax Optimization
    in Section 7 of `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    The Adamax algorithm is a variant of the Adam algorithm based on the infinity norm,
    which makes the learning rate update algorithm more stable and simple.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

        moment\_out & = {\beta}_1 * moment + (1 - {\beta}_1) * grad

        inf\_norm\_out & = max({\beta}_2 * inf\_norm + \epsilon, |grad|)

        learning\_rate & = \frac{learning\_rate}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_out}{inf\_norm\_out}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    The original paper does not have an ``epsilon`` attribute,
    it is added here for numerical stability to prevent the division by 0 error.

    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            The default value is 0.9.
        beta2 (float, optional): The exponential decay rate for the infinity norm estimates
            (``inf_norm`` in the update rule above). The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. And you can specify different options for
            different parameter groups such as the learning rate, weight decay, etc,
            then the parameters are list of dict. Note that the learning_rate in parameter groups
            represents the scale of base learning_rate.
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, Adamax doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.Adamax(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            loss.backward()
            adam.step()
            adam.clear_grad()


            # Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adamax(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            loss.backward()
            adam.step()
            adam.clear_grad()
    """
    _moment_acc_str = "moment"
    _inf_norm_acc_str = "inf_norm"
    _beta1_pow_acc_str = "beta1_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        if not 0 <= beta1 < 1:
            raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
        if not 0 <= beta2 < 1:
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "adamax"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        # No constructor argument sets this, so master weights are only
        # created if the flag is changed after construction.
        self._multi_precision = False
        # Maps parameter name -> FP32 master copy (see _create_master_weight).
        self._master_weights = {}

        # Fallback hyper-parameters for parameter groups that do not
        # override them (read by _update_param_group and _finish_update).
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
        }

    def _add_moments_pows(self, p):
        """Create the moment, infinity-norm and beta1-power accumulators for ``p``.

        FP16/BF16 parameters get FP32 moment and infinity-norm accumulators.
        The beta1-power accumulator has shape ``[1]`` and starts at
        ``self._beta1``.
        """
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32

        self._add_accumulator(self._moment_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._inf_norm_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            fill_value=self._beta1,
            shape=[1],
        )

    def _create_master_weight(self, param):
        """Return the FP32 master copy of ``param``, creating it on first use.

        The copy is a persistable global ``float32`` variable filled in the
        startup program by casting ``param``.  It is cached in
        ``self._master_weights`` under ``param.name``.
        """
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _create_accumulators(self, block, parameters):
        """Create Adamax accumulators for every parameter in ``parameters``.

        Args:
            block: unused here; kept for the base-class interface.
            parameters (list|dict): parameters, or a parameter-group dict
                whose ``'params'`` entry holds them and whose other keys may
                override ``beta1``/``beta2``/``epsilon``.
        """
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first moment and infinity norm
        for p in parameters:
            # Same FP16/BF16 test as _get_accumulator, so a master weight
            # exists exactly when _get_accumulator will look one up.
            is_low_precision = self._is_dtype_fp16_or_bf16(p.dtype)
            if self._multi_precision and is_low_precision:
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                continue
            if is_low_precision and not self._multi_precision:
                warnings.warn(
                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the Adam optimizer."
                )
            self._add_moments_pows(p)

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter.

        When multi-precision is enabled and ``param`` is FP16/BF16, the
        accumulator belonging to its FP32 master weight is returned instead.

        Args:
            name: name of the accumulator (prefixed with ``self._name`` if set)
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        Raises:
            Exception: if no such accumulator has been created.
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                f"Accumulator {name} does not exist for parameter {target_name}"
            )
        return self._accumulators[name][target_name]

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one Adamax step to a single ``(param, grad)`` pair.

        In dygraph mode ``_C_ops.adamax_`` is called and ``None`` is returned.
        In static mode an ``adamax`` op whose outputs alias the parameter,
        moment and infinity-norm variables is appended to ``block`` and
        returned.

        Args:
            block (framework.Block): block that receives the static-mode op.
            param_and_grad (tuple|dict): a ``(param, grad)`` pair, or a
                parameter-group dict whose ``'params'`` entry is that pair.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        param = param_and_grad[0]
        moment = self._get_accumulator(self._moment_acc_str, param)
        inf_norm = self._get_accumulator(self._inf_norm_acc_str, param)
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param)

        # Must match the test used by _create_accumulators/_get_accumulator,
        # otherwise a BF16 parameter would look up a missing master weight.
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        master_weight = (
            self._master_weights[param.name] if find_master else None
        )

        if framework.in_dygraph_mode():
            _C_ops.adamax_(
                param,
                param_and_grad[1],
                self._create_param_lr(param_and_grad),
                moment,
                inf_norm,
                beta1_pow_acc,
                master_weight,
                self._beta1,
                self._beta2,
                self._epsilon,
                find_master,
            )
        else:
            # create the adamax optimize op
            inputs = {
                "Param": param,
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad),
                "Moment": moment,
                "InfNorm": inf_norm,
                "Beta1Pow": beta1_pow_acc,
            }
            outputs = {
                "ParamOut": param,
                "MomentOut": moment,
                "InfNormOut": inf_norm,
            }
            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight
            attrs = {
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon,
                "multi_precision": find_master,
            }
            adamax_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )

            return adamax_op

    def _finish_update(self, block, parameters_and_grads):
        """Update Beta1 Power accumulator.

        Multiplies each trainable parameter's beta1-power accumulator by the
        active ``beta1``.  For a parameter-group dict, the group's ``beta1``
        (or the default) becomes ``self._beta1`` first.

        Args:
            block (framework.Block): block that receives static-mode ops.
            parameters_and_grads (list|dict): ``(param, grad)`` pairs, or a
                parameter-group dict whose ``'params'`` entry holds them.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters_and_grads, list):
            params_grads = parameters_and_grads
        else:
            params_grads = parameters_and_grads['params']
            self._beta1 = parameters_and_grads.get(
                'beta1', self._default_dict['beta1']
            )
        for param, grad in params_grads:
            if grad is None or param.stop_gradient is True:
                continue
            self._scale_beta1_pow(block, param, grad)

    def _scale_beta1_pow(self, block, param, grad):
        """Multiply ``param``'s beta1-power accumulator by ``self._beta1``."""
        if framework.in_dygraph_mode():
            beta1_pow_acc = self._get_accumulator(
                self._beta1_pow_acc_str, param
            )
            with no_grad():
                tmp = _C_ops.scale(beta1_pow_acc, self._beta1, 0.0, True)
                beta1_pow_acc.copy_(tmp, False)
        else:
            with param.block.program._optimized_guard(
                [param, grad]
            ), name_scope('adamax'):
                beta1_pow_acc = self._get_accumulator(
                    self._beta1_pow_acc_str, param
                )
                block.append_op(
                    type="scale",
                    inputs={"X": beta1_pow_acc},
                    outputs={"Out": beta1_pow_acc},
                    attrs={"scale": self._beta1},
                    stop_gradient=True,
                )

    def _update_param_group(self, parameters):
        """Apply a parameter group's hyper-parameters and return its params.

        ``beta1``, ``beta2`` and ``epsilon`` are taken from the group dict,
        falling back to the values given to ``__init__``.
        """
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        parameters = parameters.get('params')
        return parameters