From 6a6ed34d6cf9ae1870c742a0bc6123fb95bfb239 Mon Sep 17 00:00:00 2001
From: Hongkun Yu
Date: Thu, 13 Apr 2023 12:56:02 -0700
Subject: [PATCH] Copy lamb optimizer from Tensorflow Addons as TFA is not
 maintained. Rename lars_optimizer to lars to be consistent.

PiperOrigin-RevId: 524078216
---
 official/modeling/optimization/lamb.py       | 252 ++++++++++++++++++
 official/modeling/optimization/lamb_test.py  | 177 ++++++++++++
 .../{lars_optimizer.py => lars.py}           |   0
 .../optimization/optimizer_factory.py        |  10 +-
 4 files changed, 434 insertions(+), 5 deletions(-)
 create mode 100644 official/modeling/optimization/lamb.py
 create mode 100644 official/modeling/optimization/lamb_test.py
 rename official/modeling/optimization/{lars_optimizer.py => lars.py} (100%)

diff --git a/official/modeling/optimization/lamb.py b/official/modeling/optimization/lamb.py
new file mode 100644
index 000000000..de352c46b
--- /dev/null
+++ b/official/modeling/optimization/lamb.py
@@ -0,0 +1,252 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Layer-wise Adaptive Moments (LAMB) optimizer.
+
+See paper [Large Batch Optimization for Deep Learning: Training BERT in
+76 minutes](https://arxiv.org/abs/1904.00962).
+"""
+import re
+from typing import Optional, Union, Callable, List
+
+import numpy as np
+import tensorflow as tf
+
+FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class LAMB(tf.keras.optimizers.legacy.Optimizer):
+  """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).
+
+  See paper [Large Batch Optimization for Deep Learning: Training BERT
+  in 76 minutes](https://arxiv.org/abs/1904.00962).
+  """
+
+  def __init__(
+      self,
+      learning_rate: Union[FloatTensorLike, Callable] = 0.001,
+      beta_1: FloatTensorLike = 0.9,
+      beta_2: FloatTensorLike = 0.999,
+      epsilon: FloatTensorLike = 1e-6,
+      weight_decay_rate: FloatTensorLike = 0.0,
+      exclude_from_weight_decay: Optional[List[str]] = None,
+      exclude_from_layer_adaptation: Optional[List[str]] = None,
+      name: str = "LAMB",
+      **kwargs,
+  ):
+    """Construct a new LAMB optimizer.
+
+    Args:
+      learning_rate: A `Tensor`, a floating point value, or a schedule that
+        is a `tf.keras.optimizers.schedules.LearningRateSchedule`. The
+        learning rate.
+      beta_1: A `float` value or a constant `float` tensor. The exponential
+        decay rate for the 1st moment estimates.
+      beta_2: A `float` value or a constant `float` tensor. The exponential
+        decay rate for the 2nd moment estimates.
+      epsilon: A small constant for numerical stability.
+      weight_decay_rate: weight decay rate.
+      exclude_from_weight_decay: List of regex patterns of variables excluded
+        from weight decay. Variables whose names contain a substring matching
+        the pattern will be excluded.
+      exclude_from_layer_adaptation: List of regex patterns of variables
+        excluded from layer adaptation. Variables whose names contain a
+        substring matching the pattern will be excluded.
+      name: Optional name for the operations created when applying gradients.
+        Defaults to "LAMB".
+      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+        `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is
+        clip gradients by value; `decay` is included for backward
+        compatibility to allow time inverse decay of learning rate. `lr` is
+        included for backward compatibility, recommended to use
+        `learning_rate` instead.
+    """
+    super().__init__(name, **kwargs)
+
+    # Just adding the square of the weights to the loss function is *not*
+    # the correct way of using L2 regularization/weight decay with Adam,
+    # since that will interact with the m and v parameters in strange ways.
+    #
+    # Instead we want to decay the weights in a manner that doesn't interact
+    # with the m/v parameters.
+    self._set_hyper("weight_decay_rate", weight_decay_rate)
+    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
+
+    # This is learning rate decay for using keras learning rate schedule.
+    self._set_hyper("decay", self._initial_decay)
+    self._set_hyper("beta_1", beta_1)
+    self._set_hyper("beta_2", beta_2)
+    self.epsilon = epsilon or tf.keras.backend.epsilon()
+    self.exclude_from_weight_decay = exclude_from_weight_decay
+    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
+    # the arg is None.
+    if exclude_from_layer_adaptation:
+      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
+    else:
+      self.exclude_from_layer_adaptation = exclude_from_weight_decay
+
+  def _create_slots(self, var_list):
+    # Create slots for the first and second moments.
+    # Separate for-loops to respect the ordering of slot variables from v1.
+    for var in var_list:
+      self.add_slot(var, "m")
+    for var in var_list:
+      self.add_slot(var, "v")
+
+  def _prepare_local(self, var_device, var_dtype, apply_state):
+    super()._prepare_local(var_device, var_dtype, apply_state)
+
+    local_step = tf.cast(self.iterations + 1, var_dtype)
+    beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
+    beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
+    weight_decay_rate = tf.identity(
+        self._get_hyper("weight_decay_rate", var_dtype)
+    )
+    beta_1_power = tf.pow(beta_1_t, local_step)
+    beta_2_power = tf.pow(beta_2_t, local_step)
+    apply_state[(var_device, var_dtype)].update(
+        dict(
+            weight_decay_rate=weight_decay_rate,
+            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
+            beta_1_t=beta_1_t,
+            beta_1_power=beta_1_power,
+            one_minus_beta_1_t=1 - beta_1_t,
+            beta_2_t=beta_2_t,
+            beta_2_power=beta_2_power,
+            one_minus_beta_2_t=1 - beta_2_t,
+        )
+    )
+
+  def _resource_apply_dense(self, grad, var, apply_state=None):
+    var_device, var_dtype = var.device, var.dtype.base_dtype
+    coefficients = (apply_state or {}).get(
+        (var_device, var_dtype)
+    ) or self._fallback_apply_state(var_device, var_dtype)
+
+    # m_t = beta1 * m + (1 - beta1) * g_t
+    m = self.get_slot(var, "m")
+    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
+    m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
+    m_t = m.assign(m_t, use_locking=self._use_locking)
+    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+    v = self.get_slot(var, "v")
+    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
+    v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
+    v_t = v.assign(v_t, use_locking=self._use_locking)
+
+    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
+    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])
+
+    v_sqrt = tf.sqrt(v_t_hat)
+    update = m_t_hat / (v_sqrt + coefficients["epsilon"])
+
+    var_name = self._get_variable_name(var.name)
+    if self._do_use_weight_decay(var_name):
+      update += coefficients["weight_decay_rate"] * var
+
+    ratio = 1.0
+    if self._do_layer_adaptation(var_name):
+      w_norm = tf.norm(var, ord=2)
+      g_norm = tf.norm(update, ord=2)
+      ratio = tf.where(
+          tf.greater(w_norm, 0),
+          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+          1.0,
+      )
+
+    var_update = var - ratio * coefficients["lr_t"] * update
+    return var.assign(var_update, use_locking=self._use_locking)
+
+  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+    var_device, var_dtype = var.device, var.dtype.base_dtype
+    coefficients = (apply_state or {}).get(
+        (var_device, var_dtype)
+    ) or self._fallback_apply_state(var_device, var_dtype)
+
+    # m_t = beta1 * m + (1 - beta1) * g_t
+    m = self.get_slot(var, "m")
+    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
+    m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
+    with tf.control_dependencies([m_t]):
+      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
+
+    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+    v = self.get_slot(var, "v")
+    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
+    v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
+    with tf.control_dependencies([v_t]):
+      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
+
+    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
+    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])
+
+    v_sqrt = tf.sqrt(v_t_hat)
+    update = m_t_hat / (v_sqrt + coefficients["epsilon"])
+
+    var_name = self._get_variable_name(var.name)
+    if self._do_use_weight_decay(var_name):
+      update += coefficients["weight_decay_rate"] * var
+
+    ratio = 1.0
+    if self._do_layer_adaptation(var_name):
+      w_norm = tf.norm(var, ord=2)
+      g_norm = tf.norm(update, ord=2)
+      ratio = tf.where(
+          tf.greater(w_norm, 0),
+          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+          1.0,
+      )
+
+    var_update = var.assign_sub(
+        ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
+    )
+    return tf.group(*[var_update, m_t, v_t])
+
+  def get_config(self):
+    config = super().get_config()
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "weight_decay_rate": self._serialize_hyperparameter(
+            "weight_decay_rate"
+        ),
+        "decay": self._serialize_hyperparameter("decay"),
+        "beta_1": self._serialize_hyperparameter("beta_1"),
+        "beta_2": self._serialize_hyperparameter("beta_2"),
+        "epsilon": self.epsilon,
+    })
+    return config
+
+  def _do_use_weight_decay(self, param_name):
+    """Whether to use L2 weight decay for `param_name`."""
+    if self.exclude_from_weight_decay:
+      for r in self.exclude_from_weight_decay:
+        if re.search(r, param_name) is not None:
+          return False
+    return True
+
+  def _do_layer_adaptation(self, param_name):
+    """Whether to do layer-wise learning rate adaptation for `param_name`."""
+    if self.exclude_from_layer_adaptation:
+      for r in self.exclude_from_layer_adaptation:
+        if re.search(r, param_name) is not None:
+          return False
+    return True
+
+  def _get_variable_name(self, param_name):
+    """Get the variable name from the tensor name."""
+    m = re.match("^(.*):\\d+$", param_name)
+    if m is not None:
+      param_name = m.group(1)
+    return param_name
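For reference, the per-variable update implemented by `_resource_apply_dense` and `_resource_apply_sparse` above can be summarized as follows (a condensed restatement of the code: w is the variable, g_t its gradient, eta_t the learning rate, lambda the weight_decay_rate; the decay term and the trust ratio are skipped for variables matched by the exclude lists):

    m_t       = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t       = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t)
    \hat{v}_t = v_t / (1 - \beta_2^t)
    u_t       = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda w_{t-1}
    r_t       = \lVert w_{t-1} \rVert_2 / \lVert u_t \rVert_2   (set to 1 if either norm is 0)
    w_t       = w_{t-1} - r_t \, \eta_t \, u_t
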
diff --git a/official/modeling/optimization/lamb_test.py b/official/modeling/optimization/lamb_test.py
new file mode 100644
index 000000000..f6c41ce12
--- /dev/null
+++ b/official/modeling/optimization/lamb_test.py
@@ -0,0 +1,177 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LAMB Optimizer."""
+import numpy as np
+from numpy import linalg
+
+import tensorflow as tf
+
+from official.modeling.optimization import lamb
+
+
+def lamb_update_numpy(param,
+                      g_t,
+                      t,
+                      m,
+                      v,
+                      lr=0.001,
+                      lamb_wd=0.0,
+                      beta1=0.9,
+                      beta2=0.999,
+                      epsilon=1e-6):
+
+  m_t = beta1 * m + (1 - beta1) * g_t
+  v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+  m_t_hat = m_t / (1 - beta1**(t + 1))
+  v_t_hat = v_t / (1 - beta2**(t + 1))
+  update = m_t_hat / (np.sqrt(v_t_hat) + epsilon)
+
+  update += lamb_wd * param
+
+  w_norm = linalg.norm(param, ord=2)
+  g_norm = linalg.norm(update, ord=2)
+  ratio = np.where(w_norm > 0, np.where(g_norm > 0, (w_norm / g_norm), 1.0),
+                   1.0)
+
+  param_t = param - ratio * lr * update
+  return param_t, m_t, v_t
+
+
+def get_beta_accumulators(opt, dtype):
+  local_step = tf.cast(opt.iterations + 1, dtype)
+  beta_1_t = tf.cast(opt._get_hyper("beta_1"), dtype)
+  beta_1_power = tf.math.pow(beta_1_t, local_step)
+  beta_2_t = tf.cast(opt._get_hyper("beta_2"), dtype)
+  beta_2_power = tf.math.pow(beta_2_t, local_step)
+  return (beta_1_power, beta_2_power)
+
+
+class LAMBTest(tf.test.TestCase):
+
+  def test_sparse(self):
+    dtype = tf.float32
+    # Initialize variables for numpy implementation.
+    m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+    var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
+    grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype)
+    var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
+    grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype)
+
+    var0 = tf.Variable(var0_np)
+    var1 = tf.Variable(var1_np)
+    grads0_np_indices = np.array([0, 2], dtype=np.int32)
+    grads0 = tf.IndexedSlices(
+        tf.constant(grads0_np[grads0_np_indices]),
+        tf.constant(grads0_np_indices),
+        tf.constant([3]),
+    )
+    grads1_np_indices = np.array([0, 2], dtype=np.int32)
+    grads1 = tf.IndexedSlices(
+        tf.constant(grads1_np[grads1_np_indices]),
+        tf.constant(grads1_np_indices),
+        tf.constant([3]),
+    )
+    opt = lamb.LAMB()
+
+    # Fetch params to validate initial values
+    np.testing.assert_allclose(np.asanyarray([1.0, 1.0, 2.0]), var0.numpy())
+    np.testing.assert_allclose(np.asanyarray([3.0, 3.0, 4.0]), var1.numpy())
+
+    # Run 3 steps of LAMB
+    for t in range(3):
+      beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
+      self.assertAllClose(0.9 ** (t + 1), beta_1_power)
+      self.assertAllClose(0.999 ** (t + 1), beta_2_power)
+
+      opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+      var0_np, m0, v0 = lamb_update_numpy(var0_np, grads0_np, t, m0, v0)
+      var1_np, m1, v1 = lamb_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+      # Validate updated params
+      self.assertAllClose(var0_np, var0.numpy())
+      self.assertAllClose(var1_np, var1.numpy())
+
+  def test_basic_with_learning_rate_decay(self):
+    dtype = tf.float32
+    # Initialize variables for numpy implementation.
+    m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+    var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+    grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+    var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+    grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+    var0 = tf.Variable(var0_np, name="var0")
+    var1 = tf.Variable(var1_np, name="var1")
+    grads0 = tf.constant(grads0_np)
+    grads1 = tf.constant(grads1_np)
+
+    learning_rate = 0.001
+    beta_1 = 0.9
+    beta_2 = 0.999
+    epsilon = 1e-7
+    decay = 0.5
+    lamb_wd = 0.01
+
+    opt = lamb.LAMB(
+        learning_rate=learning_rate,
+        beta_1=beta_1,
+        beta_2=beta_2,
+        epsilon=epsilon,
+        weight_decay_rate=lamb_wd,
+        decay=decay,
+    )
+
+    # Run 3 steps of LAMB
+    for t in range(3):
+      opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+      lr_np = learning_rate / (1 + decay * t)
+
+      var0_np, m0, v0 = lamb_update_numpy(
+          var0_np, grads0_np, t, m0, v0, lr=lr_np, lamb_wd=lamb_wd)
+      var1_np, m1, v1 = lamb_update_numpy(
+          var1_np, grads1_np, t, m1, v1, lr=lr_np, lamb_wd=lamb_wd)
+
+      # Validate updated params
+      self.assertAllClose(var0_np, var0.numpy())
+      self.assertAllClose(var1_np, var1.numpy())
+
+  def test_exclude_weight_decay(self):
+    opt = lamb.LAMB(
+        0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+    )
+    assert opt._do_use_weight_decay("var0")
+    assert not opt._do_use_weight_decay("var1")
+    assert not opt._do_use_weight_decay("var1_weight")
+
+  def test_exclude_layer_adaptation(self):
+    opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
+    assert opt._do_layer_adaptation("var0")
+    assert not opt._do_layer_adaptation("var1")
+    assert not opt._do_layer_adaptation("var1_weight")
+
+  def test_serialization(self):
+    optimizer = lamb.LAMB(1e-4)
+    config = tf.keras.optimizers.serialize(optimizer, use_legacy_format=True)
+    new_optimizer = tf.keras.optimizers.deserialize(
+        config, use_legacy_format=True
+    )
+    assert new_optimizer.get_config() == optimizer.get_config()
+
+
+if __name__ == "__main__":
+  tf.test.main()
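As exercised by the tests above, the new class is a drop-in legacy Keras optimizer. A minimal usage sketch, with illustrative hyperparameter values and variable names that are not taken from this patch:

import tensorflow as tf

from official.modeling.optimization import lamb

# Illustrative: exclude LayerNorm and bias variables (matched by regex) from
# weight decay. Because exclude_from_layer_adaptation is not given, the same
# patterns are also excluded from layer adaptation.
optimizer = lamb.LAMB(
    learning_rate=1e-3,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=["LayerNorm", "bias"],
)

# A toy variable and loss, just to drive one optimizer step eagerly.
var = tf.Variable([1.0, 2.0], name="dense/kernel")
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(tf.square(var))
grads = tape.gradient(loss, [var])
optimizer.apply_gradients(zip(grads, [var]))
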
diff --git a/official/modeling/optimization/lars_optimizer.py b/official/modeling/optimization/lars.py
similarity index 100%
rename from official/modeling/optimization/lars_optimizer.py
rename to official/modeling/optimization/lars.py
diff --git a/official/modeling/optimization/optimizer_factory.py b/official/modeling/optimization/optimizer_factory.py
index 40f78c37e..ca866b9e3 100644
--- a/official/modeling/optimization/optimizer_factory.py
+++ b/official/modeling/optimization/optimizer_factory.py
@@ -13,16 +13,16 @@
 # limitations under the License.
 
 """Optimizer factory class."""
-from typing import Callable, Optional, Union, List, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import gin
 import tensorflow as tf
-import tensorflow_addons.optimizers as tfa_optimizers
 
 from official.modeling.optimization import slide_optimizer
 from official.modeling.optimization import adafactor_optimizer
 from official.modeling.optimization import ema_optimizer
-from official.modeling.optimization import lars_optimizer
+from official.modeling.optimization import lamb
+from official.modeling.optimization import lars
 from official.modeling.optimization import legacy_adamw
 from official.modeling.optimization import lr_schedule
 from official.modeling.optimization.configs import optimization_config as opt_cfg
@@ -33,8 +33,8 @@ SHARED_OPTIMIZERS = {
     'adam_experimental': tf.keras.optimizers.experimental.Adam,
     'adamw': legacy_adamw.AdamWeightDecay,
     'adamw_experimental': tf.keras.optimizers.experimental.AdamW,
-    'lamb': tfa_optimizers.LAMB,
-    'lars': lars_optimizer.LARS,
+    'lamb': lamb.LAMB,
+    'lars': lars.LARS,
     'slide': slide_optimizer.SLIDE,
     'adafactor': adafactor_optimizer.Adafactor,
 }
-- 
GitLab