# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from .adam import Adam
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid.dygraph import base as imperative_base
from collections.abc import Callable
import paddle

# Handle to the C++ operator bindings used for the dygraph fast path.
_C_ops = core.ops

__all__ = []

class AdamW(Adam):
    r"""
    The AdamW optimizer is implemented based on the AdamW Optimization
    in paper `DECOUPLED WEIGHT DECAY REGULARIZATION <https://arxiv.org/pdf/1711.05101.pdf>`_.
    It resolves the problem of L2 regularization failure in the Adam optimizer.

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate *
            \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * (\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param)


    Args:
        learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a LRScheduler. The default value is 0.001.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode. And you can specify different options for
            different parameter groups such as the learning rate, weight decay, etc,
            then the parameters are list of dict. Note that the learning_rate in parameter groups
            represents the scale of base learning_rate. Besides ``learning_rate``, AdamW reads the
            group options ``beta1``, ``beta2``, ``epsilon``, ``lazy_mode`` and ``coeff`` (the
            decoupled weight decay coefficient); options missing from a group fall back to the
            values given to this constructor.
            The default value is None in static mode, at this time all parameters will be updated.
        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.9.
        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number or a Tensor with shape [1] and data type as float32.
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
        weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor.
            Integer values are converted to float. The default value is 0.01.
        lr_ratio (function|None, optional): If it is not None,
            the learning rate will be updated with layerwise learning rate ratio:
            it is called with each parameter and its result scales that parameter's learning rate.
            Otherwise, the learning rate is the original. Only available when Paddle is
            compiled with CUDA; otherwise ``NotImplementedError`` is raised.
            Default: None.
        apply_decay_param_fun (function|None, optional): If it is not None,
            only tensors that make apply_decay_param_fun(Tensor.name)==True
            will be updated with weight decay. It only works when we want to specify tensors.
            Default: None.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
            then the update may be very slow. The lazy mode only updates the elements that have
            gradient in current mini-batch, so it will be much faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
    **Notes**:
        **Currently, AdamW doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle

            linear = paddle.nn.Linear(10, 10)
            inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)

            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")

            adam = paddle.optimizer.AdamW(learning_rate=0.1,
                    parameters=linear.parameters(),
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01)
            loss.backward()
            adam.step()
            adam.clear_grad()


            #Note that the learning_rate of linear_2 is 0.01.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            out = linear_1(inp)
            out = linear_2(out)
            loss = paddle.mean(out)
            adam = paddle.optimizer.AdamW(
                learning_rate=0.1,
                parameters=[{
                    'params': linear_1.parameters()
                }, {
                    'params': linear_2.parameters(),
                    'weight_decay': 0.001,
                    'learning_rate': 0.1,
                    'beta1': 0.8
                }],
                weight_decay=0.01,
                beta1=0.9)
            loss.backward()
            adam.step()
            adam.clear_grad()

    """

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 parameters=None,
                 weight_decay=0.01,
                 lr_ratio=None,
                 apply_decay_param_fun=None,
                 grad_clip=None,
                 lazy_mode=False,
                 multi_precision=False,
                 name=None):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        # Range checks only apply to Python numbers: Tensor-valued
        # hyper-parameters are forwarded to the op as tensor inputs (see
        # _append_optimize_op) and cannot be compared eagerly in static mode.
        if not isinstance(beta1, Variable) and not 0 <= beta1 < 1:
            raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
        if not isinstance(beta2, Variable) and not 0 <= beta2 < 1:
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not isinstance(epsilon, Variable) and not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
        coeff = weight_decay
        # Accept integral values such as ``weight_decay=0`` by promoting them
        # to float; bool is excluded because it is a subclass of int.
        if isinstance(coeff, int) and not isinstance(coeff, bool):
            coeff = float(coeff)
        if not isinstance(coeff, (float, framework.Variable)):
            raise TypeError("coeff should be float or Tensor.")
        self._params_name = set()
        self._apply_decay_param_fun = apply_decay_param_fun
        self._coeff = coeff
        # Cache learning rate -> (1 - lr * coeff) used by
        # _append_decoupled_weight_decay; reset in _create_optimization_pass.
        self._lr_to_coeff = dict()
        if lr_ratio is not None:
            assert isinstance(lr_ratio, Callable)
            if not core.is_compiled_with_cuda():
                raise NotImplementedError(
                    "'lr_ratio' is unimplemented in CPU, XPU and NPU")
        self._lr_ratio = lr_ratio

        super(AdamW, self).__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            grad_clip=grad_clip,
            name=name,
            lazy_mode=lazy_mode,
            multi_precision=multi_precision)
        # Constructor-level values that _update_param_group falls back to
        # when a parameter group does not override them.
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lazy_mode': lazy_mode,
            'coeff': coeff,
        }

        self.type = "adamw"

        # Extra variables attached from outside via _set_auxiliary_var
        # (e.g. 'found_inf', read in _append_optimize_op to skip the update).
        self._auxiliary_vars = dict()

    def _set_auxiliary_var(self, key, val):
        """Store ``val`` under ``key`` for later use by this optimizer."""
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        """Return the auxiliary variable stored under ``key``, or None."""
        return self._auxiliary_vars.get(key)

    def _append_decoupled_weight_decay(self, block, param_and_grad):
        """
        Add decoupled weight decay op.
            parameter = parameter - parameter * coeff * lr

        Parameters rejected by ``apply_decay_param_fun`` are left untouched.
        When multi-precision is enabled for an FP16 parameter, the decay is
        applied to its master weight instead.

        Args:
            block: block in which variable is to be created
            param_and_grad: (parameters, gradients) pairs,
                the parameters need to decay.
        """
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)
        param, grad = param_and_grad

        if self._apply_decay_param_fun is not None \
                and not self._apply_decay_param_fun(param.name):
            return

        if isinstance(self._learning_rate, float):
            learning_rate = self._learning_rate
        else:
            # NOTE. We add this function to the _append_optimize_op(),
            # for we must make sure _create_param_lr() be called after
            # optimizer._create_global_learning_rate().
            learning_rate = self._create_param_lr(param_and_grad)

        with block.program._optimized_guard(
            [param, grad]), framework.name_scope('weight decay'):
            self._params_name.add(param.name)

            # If it has been calculated, the result will be reused.
            # NOTE(wangxi): In dygraph mode, apply_gradient will be executed
            # every step, so need clear _lr_to_coeff every step,
            # we do this in _create_optimization_pass
            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
            if decay_coeff is None:
                # NOTE(wangxi): for pipeline to set device:all
                with paddle.static.device_guard(None):
                    decay_coeff = 1.0 - learning_rate * self._coeff
                self._lr_to_coeff[learning_rate] = decay_coeff

            find_master = (self._multi_precision and
                           param.dtype == core.VarDesc.VarType.FP16)
            if find_master:
                master_weight = self._master_weights[param.name]
                scaled_param = master_weight * decay_coeff
                paddle.fluid.layers.assign(
                    input=scaled_param, output=master_weight)
            else:
                scaled_param = param * decay_coeff
                paddle.fluid.layers.assign(input=scaled_param, output=param)

    def _append_optimize_op(self, block, param_and_grad):
        """
        Append the fused ``adamw`` update for one (param, grad) pair.

        In dygraph mode the update is executed immediately through
        ``_C_ops.adamw`` and None is returned; otherwise an ``adamw`` op is
        appended to ``block`` and returned.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)
        param, grad = param_and_grad

        # Parameters rejected by apply_decay_param_fun skip weight decay.
        with_decay = (self._apply_decay_param_fun is None or
                      bool(self._apply_decay_param_fun(param.name)))

        moment1 = self._get_accumulator(self._moment1_acc_str, param)
        moment2 = self._get_accumulator(self._moment2_acc_str, param)
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param)
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, param)
        find_master = (self._multi_precision and
                       param.dtype == core.VarDesc.VarType.FP16)
        master_weight = (self._master_weights[param.name]
                         if find_master else None)
        lr = self._create_param_lr(param_and_grad)
        # Layer-wise learning-rate scale; 1.0 when no lr_ratio was given.
        lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio(param)

        # create the adamw optimize op
        if framework.in_dygraph_mode():
            # The dygraph kernel takes beta1/beta2 as attributes, so
            # Tensor-valued betas are converted to Python scalars here.
            _beta1 = self._beta1 if not isinstance(
                self._beta1, Variable) else self._beta1.numpy().item(0)
            _beta2 = self._beta2 if not isinstance(
                self._beta2, Variable) else self._beta2.numpy().item(0)

            _C_ops.adamw(
                param, grad, lr, moment1, moment2, beta1_pow_acc,
                beta2_pow_acc, master_weight, param, moment1, moment2,
                beta1_pow_acc, beta2_pow_acc, master_weight, 'epsilon',
                self._epsilon, 'lazy_mode', self._lazy_mode,
                'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
                'beta2', _beta2, "with_decay", with_decay, 'coeff', self._coeff,
                'multi_precision', find_master, 'lr_ratio', lr_ratio_)
            return None

        inputs = {
            "Param": [param],
            "Grad": [grad],
            "LearningRate": [lr],
            "Moment1": [moment1],
            "Moment2": [moment2],
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc],
        }

        # Pass found_inf to adamw, to skip update for not only param, but also momentum and beta_pow
        found_inf = self._get_auxiliary_var('found_inf')
        if found_inf is not None:
            inputs['SkipUpdate'] = found_inf

        outputs = {
            "ParamOut": [param],
            "Moment1Out": [moment1],
            "Moment2Out": [moment2],
            "Beta1PowOut": [beta1_pow_acc],
            "Beta2PowOut": [beta2_pow_acc],
        }
        attrs = {
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000,
            "multi_precision": find_master,
            "with_decay": with_decay,
            "coeff": self._coeff,
            "lr_ratio": lr_ratio_,
        }

        # Tensor hyper-parameters are wired in as op inputs; Python scalars
        # become attributes.
        if isinstance(self._beta1, Variable):
            inputs['Beta1Tensor'] = self._beta1
        else:
            attrs['beta1'] = self._beta1
        if isinstance(self._beta2, Variable):
            inputs['Beta2Tensor'] = self._beta2
        else:
            attrs['beta2'] = self._beta2
        if isinstance(self._epsilon, Variable):
            inputs['EpsilonTensor'] = self._epsilon
        else:
            attrs['epsilon'] = self._epsilon

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adamw_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True)

        return adamw_op

    def _create_optimization_pass(self, parameters_and_grads):
        """Run the base optimization pass, then reset the decay cache."""
        optimize_ops = super(
            AdamW, self)._create_optimization_pass(parameters_and_grads)
        # In dygraph mode, clear _lr_to_coeff after applied gradient
        self._lr_to_coeff = dict()
        return optimize_ops

    def __str__(self):
        return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

    def _update_param_group(self, parameters):
        """
        Apply the options of one parameter group and return its params.

        ``beta1``, ``beta2``, ``epsilon``, ``lazy_mode`` and ``coeff`` are read
        from the group; any option the group omits falls back to the value
        given to the constructor, so one group's settings never leak into the
        next group.

        NOTE(review): a group's ``'weight_decay'`` entry is not read here; the
        decoupled decay of a group is controlled by ``'coeff'``. Confirm how
        the base Optimizer handles ``'weight_decay'`` in parameter groups.
        """
        defaults = self._default_dict
        self._beta1 = parameters.get('beta1', defaults['beta1'])
        self._beta2 = parameters.get('beta2', defaults['beta2'])
        self._epsilon = parameters.get('epsilon', defaults['epsilon'])
        self._lazy_mode = parameters.get('lazy_mode', defaults['lazy_mode'])
        self._coeff = parameters.get('coeff', defaults['coeff'])
        return parameters.get('params')