lamb.py 14.2 KB
Newer Older
T
Thomas Young 已提交
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import paddle
16
from paddle import _C_ops, _legacy_C_ops
17
from paddle.fluid.executor import global_scope
18

19
from ..fluid import core, framework, unique_name
20 21 22 23
from ..fluid.framework import Variable
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

24 25
__all__ = []

26 27

class Lamb(Optimizer):
    r"""
    LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.

    LAMB Optimizer is designed to scale up the batch size of training without losing
    accuracy, which supports adaptive element-wise updating and accurate layer-wise
    correction. For more information, please refer to `Large Batch Optimization for
    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .

    The updating of parameters follows:

    ..  math::

        m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t

        v_t &= \beta_2 v_{t - 1}  + (1 - \beta_2)g_t^2

        \hat{m}_t &= \frac{m_t}{1 - \beta_1^t}

        \hat{v}_t &= \frac{v_t}{1 - \beta_2^t}

        r_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}

        w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})


    where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\eta` the
    learning rate, :math:`\lambda` the LAMB weight decay rate.

    Args:
        learning_rate (float|Variable, optional): the learning rate used to update parameters. \
            Can be a float value or a Variable with data type float32. Default 0.001.
        lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01.
            The generic ``weight_decay`` of the base optimizer is not used by LAMB
            (it is always passed as None).
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            Default 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            Default 0.999.
        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
        parameters (Iterable, optional): Iterable of parameters (Tensors) to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in parameter groups \
            represents the scale of base learning_rate. \
            The default value is None in static mode, at this time all parameters will be updated.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
            :ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
            to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
        exclude_from_weight_decay_fn (callable, optional): A function that takes a parameter
            and returns True when LAMB weight decay should NOT be applied to it (a weight
            decay of 0.0 is used for that parameter instead). Default None, meaning weight
            decay is applied to every parameter.
        multi_precision (bool, optional): If True, FP16 parameters get an FP32 master copy
            whose moments are kept in FP32 and which is updated by the LAMB op. Default False.
        name(str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually name is no need to set and None by default.
    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
            loss.backward()
            lamb.step()
            lamb.clear_grad()

    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        multi_precision=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        # LAMB applies its own decay (lamb_weight_decay) inside the op, so the
        # base optimizer's generic weight decay is explicitly disabled.
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=None,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "lamb"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lamb_weight_decay = lamb_weight_decay
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        # Defaults restored by _update_param_group for keys a parameter
        # group dict does not override.
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lamb_weight_decay': lamb_weight_decay,
            'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
        }
        # param name -> FP32 master weight Variable (see _create_master_weight).
        self._master_weights = {}
        # param name -> master weight name, recorded when an optimize op
        # actually uses the master weight; read back by _get_parameter.
        self._used_master_weights = {}
        # TODO(zengjinle): expose API as soon as possible
        self._multi_precision = multi_precision

    def _get_parameter(self, name, scope=None):
        """Fetch a parameter's tensor and, if one was used, its master copy.

        Args:
            name (str): name of the parameter variable.
            scope (optional): scope to look the variables up in. Defaults to
                ``global_scope()``.

        Returns:
            tuple: ``(param_tensor, master_tensor)``. ``master_tensor`` is None
            unless ``_append_optimize_op`` recorded a master weight for ``name``.
        """
        if scope is None:
            scope = global_scope()

        p_t = scope.find_var(name).get_tensor()

        master_name = self._used_master_weights.get(name)
        if master_name is None:
            return p_t, None

        master_p_t = scope.find_var(master_name).get_tensor()
        # A master weight must differ in dtype but match in shape.
        assert master_p_t._dtype() != p_t._dtype()
        assert master_p_t.shape() == p_t.shape()
        return p_t, master_p_t

    def _create_master_weight(self, param):
        """Return the FP32 master weight for ``param``, creating it on first use.

        A new persistable FP32 global variable named ``<param>_fp32_master``
        (made unique) is created, and a ``cast`` op that fills it from
        ``param`` is appended to the startup program's global block. The
        result is cached in ``self._master_weights``.
        """
        assert self._multi_precision
        if param.name in self._master_weights:
            return self._master_weights[param.name]

        assert isinstance(self.helper, LayerHelper)

        var_name = unique_name.generate(param.name + "_fp32_master")
        var = paddle.static.create_global_var(
            name=var_name,
            shape=param.shape,
            value=0,
            dtype='float32',
            persistable=True,
        )
        block = self.helper.startup_program.global_block()
        block.append_op(
            type="cast",
            inputs={"X": [param]},
            outputs={"Out": [var]},
            attrs={
                "in_dtype": param.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
            },
        )
        self._master_weights[param.name] = var
        return var

    def _create_accumulators(self, block, parameters):
        """Create moment and beta-power accumulators for every parameter.

        When multi-precision is on, FP16 parameters get their accumulators
        attached to the FP32 master weight instead of the parameter itself.
        ``parameters`` may be a parameter-group dict, in which case the
        group's hyper-parameters are applied first.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
            else:
                self._add_moments_pows(p)

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        For FP16 parameters under multi-precision, the accumulator is looked
        up on the parameter's FP32 master weight (matching
        ``_create_accumulators``).

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        Raises:
            Exception: if the accumulator was never created.
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _add_moments_pows(self, p):
        """Register moment1/moment2 and beta1/beta2 power accumulators for ``p``.

        FP16 accumulators are promoted to FP32. The beta-power accumulators
        are shape-[1] tensors placed on CPU and initialised to beta1/beta2
        (or to 0.9/0.999 when the corresponding beta is a Variable).
        """
        acc_dtype = p.dtype
        if acc_dtype == core.VarDesc.VarType.FP16:
            acc_dtype = core.VarDesc.VarType.FP32

        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one LAMB update to ``param_and_grad = (param, grad)``.

        In dygraph mode the update runs immediately through ``_C_ops.lamb_``
        (or ``_legacy_C_ops.lamb`` in legacy non-static mode) and None is
        returned. Otherwise a ``lamb`` op is appended to ``block`` and
        returned.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        # NOTE(review): this flag is read outside this file; presumably
        # program passes special-case programs that contain LAMB.
        block.program._use_lamb = True

        param = param_and_grad[0]
        moment1 = self._get_accumulator(self._moment1_acc_str, param)
        moment2 = self._get_accumulator(self._moment2_acc_str, param)
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param)
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, param)

        # Parameters selected by the exclusion hook are updated without decay.
        if (
            self._exclude_from_weight_decay_fn is not None
            and self._exclude_from_weight_decay_fn(param)
        ):
            weight_decay = 0.0
        else:
            weight_decay = self._lamb_weight_decay
        lr = self._create_param_lr(param_and_grad)

        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        p_name = param.name
        if find_master:
            master_weight = self._master_weights[p_name]
            # Remember the pairing so _get_parameter can return the master.
            self._used_master_weights[p_name] = master_weight.name
        else:
            master_weight = None
        # Optional auxiliary variable; when present it is passed to the op
        # (presumably so an update can be skipped on inf/nan, e.g. under AMP).
        found_inf = self._get_auxiliary_var('found_inf')

        if framework.in_dygraph_mode():
            _C_ops.lamb_(
                param,
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                found_inf,
                weight_decay,
                self._beta1,
                self._beta2,
                self._epsilon,
                find_master,
            )
            return None

        if framework._non_static_mode():
            # Legacy op signature: inputs, then the in-place outputs, then
            # attribute name/value pairs.
            _legacy_C_ops.lamb(
                param,
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                param,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                'beta1',
                self._beta1,
                'beta2',
                self._beta2,
                'epsilon',
                self._epsilon,
                'weight_decay',
                weight_decay,
                'multi_precision',
                find_master,
            )
            return None

        # create the lamb optimize op (static graph); outputs alias inputs so
        # the update is in place.
        inputs = {
            "Param": param,
            "Grad": param_and_grad[1],
            "LearningRate": lr,
            "Moment1": moment1,
            "Moment2": moment2,
            "Beta1Pow": beta1_pow_acc,
            "Beta2Pow": beta2_pow_acc,
        }
        outputs = {
            "ParamOut": param,
            "Moment1Out": moment1,
            "Moment2Out": moment2,
            "Beta1PowOut": beta1_pow_acc,
            "Beta2PowOut": beta2_pow_acc,
        }
        attrs = {
            "beta1": self._beta1,
            "beta2": self._beta2,
            "epsilon": self._epsilon,
            "weight_decay": weight_decay,
            "multi_precision": find_master,
        }

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        # Explicit None check instead of relying on a Variable's truthiness.
        if found_inf is not None:
            inputs["SkipUpdate"] = found_inf

        lamb_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

        return lamb_op

    def _update_param_group(self, parameters):
        """Apply a parameter group's hyper-parameters and return its params.

        Each LAMB hyper-parameter is taken from the group dict when present,
        otherwise reset to the constructor default. Note that this mutates
        the optimizer's current settings (``self._beta1`` etc.).

        Args:
            parameters (dict): a parameter group; its ``'params'`` entry holds
                the parameters.

        Returns:
            The value of ``parameters['params']`` (None if absent).
        """
        defaults = self._default_dict
        self._beta1 = parameters.get('beta1', defaults['beta1'])
        self._beta2 = parameters.get('beta2', defaults['beta2'])
        self._epsilon = parameters.get('epsilon', defaults['epsilon'])
        self._lamb_weight_decay = parameters.get(
            'lamb_weight_decay', defaults['lamb_weight_decay']
        )
        self._exclude_from_weight_decay_fn = parameters.get(
            'exclude_from_weight_decay_fn',
            defaults['exclude_from_weight_decay_fn'],
        )
        return parameters.get('params')