# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops
from paddle.fluid.executor import global_scope

from ..fluid import core, framework, unique_name
from ..fluid.framework import Variable
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

__all__ = []


class Lamb(Optimizer):
    r"""
    LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.

    The LAMB optimizer is designed to scale up the batch size of training without losing
    accuracy. It supports adaptive element-wise updating and accurate layer-wise
    correction. For more information, please refer to `Large Batch Optimization for
    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .

    The parameter update follows:

    ..  math::

        m_t &= \beta_1 m_{t - 1} + (1 - \beta_1)g_t

        v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2

        m_t &= \frac{m_t}{1 - \beta_1^t}

        v_t &= \frac{v_t}{1 - \beta_2^t}

        r_t &= \frac{m_t}{\sqrt{v_t} + \epsilon}

        w_t &= w_{t-1} - \eta_t \frac{\left \| w_{t-1} \right \|}{\left \| r_t + \lambda w_{t-1} \right \|} (r_t + \lambda w_{t-1})

    where :math:`m` is the first moment, :math:`v` the second moment, :math:`\eta_t` the
    learning rate, and :math:`\lambda` the LAMB weight decay rate.
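
    The following is a minimal NumPy sketch of a single update step that mirrors the
    formulas above. It is illustrative only; the names (``lamb_step``, ``param``,
    ``grad``, ``m``, ``v``, ``step``) are assumptions for this sketch, not part of
    the Paddle API.

    .. code-block:: python

        import numpy as np

        def lamb_step(param, grad, m, v, step, lr=0.001, beta1=0.9,
                      beta2=0.999, epsilon=1e-6, weight_decay=0.01):
            # First and second moment estimates.
            m = beta1 * m + (1 - beta1) * grad
            v = beta2 * v + (1 - beta2) * grad ** 2
            # Bias correction.
            m_hat = m / (1 - beta1 ** step)
            v_hat = v / (1 - beta2 ** step)
            # Adam-style ratio plus decoupled (LAMB) weight decay.
            r = m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * param
            # Layer-wise trust ratio scales the update by ||w|| / ||r||.
            w_norm = np.linalg.norm(param)
            r_norm = np.linalg.norm(r)
            trust = w_norm / r_norm if w_norm > 0 and r_norm > 0 else 1.0
            return param - lr * trust * r, m, v

        # One step on a toy parameter.
        w = np.random.rand(10, 10).astype('float32')
        g = np.random.rand(10, 10).astype('float32')
        w, m, v = lamb_step(w, g, np.zeros_like(w), np.zeros_like(w), step=1)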

    Args:
        learning_rate (float|Variable, optional): The learning rate used to update parameters. \
            It can be a float value or a Variable with data type float32. Default 0.001.
        lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Note that \
            LAMB applies this decay itself, so the base optimizer's ``weight_decay`` is kept as None.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            Default 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            Default 0.999.
        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
        parameters (Iterable, optional): Iterable of ``Variable`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. You can also specify different options \
            (such as the learning rate or weight decay) for different parameter groups by passing \
            a list of dict. Note that the learning_rate in a parameter group \
            represents a scale of the base learning_rate. \
            The default value is None in static graph mode, in which case all parameters are updated.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
            :ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
            to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
        exclude_from_weight_decay_fn (function, optional): A function that takes a parameter and
            returns True if weight decay should be skipped for that parameter. Default None.
        multi_precision (bool, optional): Whether to keep float32 master weights for float16
            parameters during updating. Default False.
        name (str|None, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name` . Default None.
    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
            loss.backward()
            lamb.step()
            lamb.clear_grad()
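
        A sketch of per-parameter-group options (the group values below are
        illustrative, not recommendations). Each dict must contain a ``'params'``
        key; the other keys override the constructor defaults for that group.

        .. code-block:: python

            import paddle

            linear_1 = paddle.nn.Linear(10, 10)
            linear_2 = paddle.nn.Linear(10, 10)
            lamb = paddle.optimizer.Lamb(
                learning_rate=0.002,
                parameters=[
                    {'params': linear_1.parameters()},
                    {
                        'params': linear_2.parameters(),
                        'lamb_weight_decay': 0.0,
                        'beta1': 0.9,
                    },
                ],
            )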

    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        multi_precision=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=None,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "lamb"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lamb_weight_decay = lamb_weight_decay
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lamb_weight_decay': lamb_weight_decay,
            'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
        }
        self._master_weights = {}
        self._used_master_weights = {}
        # TODO(zengjinle): expose API as soon as possible
        self._multi_precision = multi_precision

    def _get_parameter(self, name, scope=None):
        if scope is None:
            scope = global_scope()

        p_t = scope.find_var(name).get_tensor()

        master_name = self._used_master_weights.get(name)
        if master_name is not None:
            master_p_t = scope.find_var(master_name).get_tensor()
            assert master_p_t._dtype() != p_t._dtype()
            assert master_p_t.shape() == p_t.shape()
        else:
            master_p_t = None
        return p_t, master_p_t

    def _create_master_weight(self, param):
        assert self._multi_precision
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
            else:
                self._add_moments_pows(p)

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter
        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched
        Returns:
            accumulator variable for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _add_moments_pows(self, p):
        acc_dtype = p.dtype
        if acc_dtype == core.VarDesc.VarType.FP16:
            acc_dtype = core.VarDesc.VarType.FP32

        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        block.program._use_lamb = True

        moment1 = self._get_accumulator(
            self._moment1_acc_str, param_and_grad[0]
        )
        moment2 = self._get_accumulator(
            self._moment2_acc_str, param_and_grad[0]
        )
        beta1_pow_acc = self._get_accumulator(
            self._beta1_pow_acc_str, param_and_grad[0]
        )
        beta2_pow_acc = self._get_accumulator(
            self._beta2_pow_acc_str, param_and_grad[0]
        )

        if (
            self._exclude_from_weight_decay_fn is not None
            and self._exclude_from_weight_decay_fn(param_and_grad[0])
        ):
            weight_decay = 0.0
        else:
            weight_decay = self._lamb_weight_decay
        lr = self._create_param_lr(param_and_grad)

        find_master = (
            self._multi_precision
            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
        )
        p_name = param_and_grad[0].name
        if find_master:
            master_weight = self._master_weights[p_name]
            self._used_master_weights[p_name] = master_weight.name
        else:
            master_weight = None

        if framework.in_dygraph_mode():
            _C_ops.lamb_(
                param_and_grad[0],
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                None,
                weight_decay,
                self._beta1,
                self._beta2,
                self._epsilon,
                find_master,
            )
            return None
        else:
            # create the lamb optimize op
            inputs = {
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": lr,
                "Moment1": moment1,
                "Moment2": moment2,
                "Beta1Pow": beta1_pow_acc,
                "Beta2Pow": beta2_pow_acc,
            }
            outputs = {
                "ParamOut": param_and_grad[0],
                "Moment1Out": moment1,
                "Moment2Out": moment2,
                "Beta1PowOut": beta1_pow_acc,
                "Beta2PowOut": beta2_pow_acc,
            }
            attrs = {
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon,
                "weight_decay": weight_decay,
                "multi_precision": find_master,
            }

            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight

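            # 'found_inf' is an auxiliary variable typically set by the AMP
            # gradient scaler; when present it is passed as SkipUpdate so the
            # op can skip this step on inf/nan gradients.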
            found_inf = self._get_auxiliary_var('found_inf')
            if found_inf:
                inputs["SkipUpdate"] = found_inf

            lamb_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )

            return lamb_op

    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._lamb_weight_decay = parameters.get(
            'lamb_weight_decay', self._default_dict['lamb_weight_decay']
        )
        self._exclude_from_weight_decay_fn = parameters.get(
            'exclude_from_weight_decay_fn',
            self._default_dict['exclude_from_weight_decay_fn'],
        )
        parameters = parameters.get('params')
        return parameters