Commit e01cbc24 authored by Fan Yang, committed by A. Unique TensorFlower

Change reference to LAMB optimizer from TFA to Model Garden's implementation.

PiperOrigin-RevId: 524902271
Parent: e326be64
@@ -220,8 +220,7 @@ class AdamWeightDecayExperimentalConfig(BaseOptimizerConfig):
 class LAMBConfig(BaseOptimizerConfig):
   """Configuration for LAMB optimizer.

-  The attributes for this class matches the arguments of
-  tensorflow_addons.optimizers.LAMB.
+  The attributes for this class matches the arguments of LAMB optimizer.

   Attributes:
     name: name of the optimizer.
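For context, LAMBConfig's attributes are meant to mirror the LAMB constructor arguments one-to-one. A hypothetical sketch of building the config with the same values this commit passes to the optimizer below; the import path and the name field are assumptions not shown in this diff:

# Hypothetical sketch: the import path and any fields not visible in the
# diff above are assumptions; the values mirror the LAMB call in this commit.
from official.modeling.optimization.configs import optimizer_config

lamb_config = optimizer_config.LAMBConfig(
    name='LAMB',
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
)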
@@ -17,10 +17,12 @@
 from absl import logging
 import gin
 import tensorflow as tf
-import tensorflow_addons.optimizers as tfa_optimizers
+
+from official.modeling.optimization import lamb
 from official.modeling.optimization import legacy_adamw

 AdamWeightDecay = legacy_adamw.AdamWeightDecay
+LAMB = lamb.LAMB


 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -97,13 +99,14 @@ def create_optimizer(init_lr,
         exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
   elif optimizer_type == 'lamb':
     logging.info('using Lamb optimizer')
-    optimizer = tfa_optimizers.LAMB(
+    optimizer = LAMB(
         learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=beta_1,
         beta_2=0.999,
         epsilon=1e-6,
-        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
+        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
+    )
   else:
     raise ValueError('Unsupported optimizer type: ', optimizer_type)
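With this change, the 'lamb' branch of create_optimizer resolves to Model Garden's implementation instead of TFA's, so tensorflow_addons is no longer needed on this path. A minimal sketch of constructing the optimizer directly, mirroring the arguments shown above; the constant learning rate is illustrative, since create_optimizer passes a schedule instead:

# Minimal sketch mirroring the LAMB arguments used in create_optimizer
# above; the constant learning rate stands in for the lr_schedule that
# create_optimizer builds (e.g. a WarmUp schedule).
from official.modeling.optimization import lamb

optimizer = lamb.LAMB(
    learning_rate=1e-3,
    weight_decay_rate=0.01,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
)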