# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.executor import global_scope

__all__ = []


class Lamb(Optimizer):
    r"""
    LAMB (Layer-wise Adaptive Moments optimizer for Batch training) Optimizer.

    LAMB Optimizer is designed to scale up the batch size of training without losing
    accuracy, which supports adaptive element-wise updating and accurate layer-wise
    correction. For more information, please refer to `Large Batch Optimization for
    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .

    The updating of parameters follows:

    ..  math::

        m_t &= \beta_1 m_{t - 1} + (1 - \beta_1)g_t

        v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2

        m_t &= \frac{m_t}{\beta_1^t}

        v_t &= \frac{v_t}{\beta_2^t}

        r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon}

        w_t &= w_{t-1} - \eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})


    where :math:`m` is the 1st moment, :math:`v` the 2nd moment, :math:`\eta` the
    learning rate, and :math:`\lambda` the LAMB weight decay rate.

    Args:
        learning_rate (float|Variable, optional): the learning rate used to update parameters. \
            Can be a float value or a Variable with data type float32. Default 0.001.
        lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Note that the
            base optimizer's ``weight_decay`` is always None for LAMB; use ``lamb_weight_decay`` instead.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            Default 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            Default 0.999.
        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
        parameters (Iterable, optional): Iterable of ``Variable`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. You can also specify different options for \
            different parameter groups, such as the learning rate and weight decay; in that case, \
            ``parameters`` is a list of dicts. Note that the ``learning_rate`` in a parameter group \
            is a scale factor applied to the base ``learning_rate``. \
            The default value is None in static graph mode, in which case all parameters are updated.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
            :ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
            to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
        exclude_from_weight_decay_fn (function, optional): A function that takes a parameter and
            returns True if weight decay should be skipped for that parameter. Default None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name` . Default None.
    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
            loss.backward()
            lamb.step()
            lamb.clear_grad()
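
            # A minimal sketch of per-group options (hypothetical layer names),
            # assuming the list-of-dict ``parameters`` format described above:
            # each group may override ``learning_rate`` (as a scale of the base
            # rate) and ``lamb_weight_decay``.
            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            lamb = paddle.optimizer.Lamb(
                learning_rate=0.002,
                parameters=[
                    {'params': linear_1.parameters()},
                    {'params': linear_2.parameters(),
                     'learning_rate': 0.1,
                     'lamb_weight_decay': 0.02},
                ],
                lamb_weight_decay=0.01)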

    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        multi_precision=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=None,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "lamb"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lamb_weight_decay = lamb_weight_decay
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lamb_weight_decay': lamb_weight_decay,
            'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
        }
        self._master_weights = {}
        self._used_master_weights = {}
        # TODO(zengjinle): expose API as soon as possible
        self._multi_precision = multi_precision

    def _get_parameter(self, name, scope=None):
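        """Fetch the parameter tensor named ``name`` from ``scope``, together
        with its FP32 master weight tensor if one was recorded for
        multi-precision training (otherwise the second return value is None).
        """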
        if scope is None:
            scope = global_scope()

        p_t = scope.find_var(name).get_tensor()

        master_name = self._used_master_weights.get(name)
        if master_name is not None:
            master_p_t = scope.find_var(master_name).get_tensor()
            assert master_p_t._dtype() != p_t._dtype()
            assert master_p_t.shape() == p_t.shape()
        else:
            master_p_t = None
        return p_t, master_p_t

    def _create_master_weight(self, param):
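        """Create (or reuse) an FP32 master weight variable for ``param`` and
        append a cast op to the startup program that initializes it from the
        original parameter.
        """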
        assert self._multi_precision
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = layers.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _create_accumulators(self, block, parameters):
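        """Create the moment and beta-pow accumulators for each parameter; FP16
        parameters get their accumulators attached to an FP32 master weight
        when multi-precision is enabled.
        """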
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
            else:
                self._add_moments_pows(p)

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter
        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _add_moments_pows(self, p):
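        """Add the moment1/moment2 accumulators and the 1-element beta1/beta2
        power accumulators (placed on CPU) for parameter ``p``, using FP32
        storage when the parameter itself is FP16.
        """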
        acc_dtype = p.dtype
        if acc_dtype == core.VarDesc.VarType.FP16:
            acc_dtype = core.VarDesc.VarType.FP32

        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        block.program._use_lamb = True

        moment1 = self._get_accumulator(
            self._moment1_acc_str, param_and_grad[0]
        )
        moment2 = self._get_accumulator(
            self._moment2_acc_str, param_and_grad[0]
        )
        beta1_pow_acc = self._get_accumulator(
            self._beta1_pow_acc_str, param_and_grad[0]
        )
        beta2_pow_acc = self._get_accumulator(
            self._beta2_pow_acc_str, param_and_grad[0]
        )

        if (
            self._exclude_from_weight_decay_fn is not None
            and self._exclude_from_weight_decay_fn(param_and_grad[0])
        ):
            weight_decay = 0.0
        else:
            weight_decay = self._lamb_weight_decay
        lr = self._create_param_lr(param_and_grad)

        find_master = (
            self._multi_precision
            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
        )
        p_name = param_and_grad[0].name
        if find_master:
            master_weight = self._master_weights[p_name]
            self._used_master_weights[p_name] = master_weight.name
        else:
            master_weight = None
        found_inf = self._get_auxiliary_var('found_inf')

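        # Three execution paths follow: the new dygraph op (_C_ops.lamb_), the
        # legacy dygraph op (_legacy_C_ops.lamb), and the static-graph lamb op
        # appended to the block at the end of this method.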
        if framework.in_dygraph_mode():
            _C_ops.lamb_(
                param_and_grad[0],
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                found_inf,
                weight_decay,
                self._beta1,
                self._beta2,
                self._epsilon,
                find_master,
            )
            return None
        if framework._non_static_mode():
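            # The legacy op expects its outputs positionally after the inputs,
            # so the parameter, moments, beta-pow accumulators and master
            # weight are passed a second time as in-place outputs.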
            _legacy_C_ops.lamb(
                param_and_grad[0],
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                param_and_grad[0],
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                'beta1',
                self._beta1,
                'beta2',
                self._beta2,
                'epsilon',
                self._epsilon,
                'weight_decay',
                weight_decay,
                'multi_precision',
                find_master,
            )
            return None

        # create the lamb optimize op
        inputs = {
            "Param": param_and_grad[0],
            "Grad": param_and_grad[1],
            "LearningRate": lr,
            "Moment1": moment1,
            "Moment2": moment2,
            "Beta1Pow": beta1_pow_acc,
            "Beta2Pow": beta2_pow_acc,
        }
        outputs = {
            "ParamOut": param_and_grad[0],
            "Moment1Out": moment1,
            "Moment2Out": moment2,
            "Beta1PowOut": beta1_pow_acc,
            "Beta2PowOut": beta2_pow_acc,
        }
        attrs = {
            "beta1": self._beta1,
            "beta2": self._beta2,
            "epsilon": self._epsilon,
            "weight_decay": weight_decay,
            "multi_precision": find_master,
        }

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

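        # If AMP exposes a found_inf flag, pass it as SkipUpdate so the op can
        # skip this step's parameter update when an inf/nan was detected.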
        if found_inf:
            inputs["SkipUpdate"] = found_inf

        lamb_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

        return lamb_op

    def _update_param_group(self, parameters):
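        """Load per-group hyperparameters (falling back to the defaults stored
        in ``self._default_dict``) and return the group's parameter list.
        """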
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get(
            'epsilon', self._default_dict['epsilon']
        )
        self._lamb_weight_decay = parameters.get(
            'lamb_weight_decay', self._default_dict['lamb_weight_decay']
        )
        self._exclude_from_weight_decay_fn = parameters.get(
            'exclude_from_weight_decay_fn',
            self._default_dict['exclude_from_weight_decay_fn'],
        )
        parameters = parameters.get('params')
        return parameters