Commit a7894f9e authored by Chen Qian, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 424391275
Parent 885fda09
@@ -45,6 +45,8 @@ class OptimizerConfig(oneof.OneOfConfig):
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
  sgd_experimental: opt_cfg.SGDExperimentalConfig = (
      opt_cfg.SGDExperimentalConfig())
  adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
  adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
  lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
......
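For illustration, a minimal sketch of selecting the new oneof entry; the module paths and the `get()` accessor are assumed from the surrounding TF Model Garden code, not shown in this diff:

    from official.modeling.optimization.configs import optimization_config
    from official.modeling.optimization.configs import optimizer_config as opt_cfg

    # In a oneof config, `type` names the active sub-config and get() returns it.
    opt = optimization_config.OptimizerConfig(
        type='sgd_experimental',
        sgd_experimental=opt_cfg.SGDExperimentalConfig(momentum=0.9))
    assert isinstance(opt.get(), opt_cfg.SGDExperimentalConfig)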
@@ -54,6 +54,26 @@ class SGDConfig(BaseOptimizerConfig):
  momentum: float = 0.0


# TODO(b/216129465): Merge this config with SGDConfig after the experimental
# optimizer graduates.
@dataclasses.dataclass
class SGDExperimentalConfig(BaseOptimizerConfig):
  """Configuration for the experimental SGD optimizer.

  The attributes of this class match the arguments of
  `tf.keras.optimizers.experimental.SGD`.

  Attributes:
    name: name of the optimizer.
    nesterov: whether to use Nesterov momentum.
    momentum: momentum for the SGD optimizer.
    jit_compile: if True, use XLA to compile the optimizer update.
  """
  name: str = "SGD"
  nesterov: bool = False
  momentum: float = 0.0
  jit_compile: bool = False


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.
......
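As a rough sketch of what each new config field maps onto, assuming a TensorFlow release that ships the experimental optimizer namespace (TF 2.8 era); the learning rate here is an arbitrary example value:

    import tensorflow as tf

    from official.modeling.optimization.configs import optimizer_config as opt_cfg

    cfg = opt_cfg.SGDExperimentalConfig(momentum=0.9, nesterov=True)
    # Each attribute corresponds one-to-one to a constructor argument of the
    # experimental SGD; the learning rate is supplied separately.
    optimizer = tf.keras.optimizers.experimental.SGD(
        learning_rate=0.01,
        momentum=cfg.momentum,
        nesterov=cfg.nesterov,
        jit_compile=cfg.jit_compile)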
@@ -18,7 +18,6 @@ from typing import Callable, Optional, Union, List, Tuple
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers

from official.modeling.optimization import slide_optimizer
from official.modeling.optimization import adafactor_optimizer
from official.modeling.optimization import ema_optimizer
@@ -29,6 +28,7 @@ from official.nlp import optimization as nlp_optimization
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'sgd_experimental': tf.keras.optimizers.experimental.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
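Roughly, the factory resolves the config's `type` string through this registry; a simplified sketch of that lookup (not the factory's exact code):

    import tensorflow as tf

    # The oneof `type` value doubles as the registry key, so the new entry is
    # all that is needed to make the experimental SGD reachable from configs.
    optimizer_cls = {
        'sgd': tf.keras.optimizers.SGD,
        'sgd_experimental': tf.keras.optimizers.experimental.SGD,
    }['sgd_experimental']
    optimizer = optimizer_cls(learning_rate=0.01, momentum=0.9)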
@@ -178,7 +178,8 @@ class OptimizerFactory:
        takes an optimizer and returns an optimizer.

    Returns:
      tf.keras.optimizers.Optimizer instance.
      `tf.keras.optimizers.Optimizer` or
      `tf.keras.optimizers.experimental.Optimizer` instance.
    """
    optimizer_dict = self._optimizer_config.as_dict()
@@ -201,8 +202,10 @@
          optimizer, **self._ema_config.as_dict())
    if postprocessor:
      optimizer = postprocessor(optimizer)
    assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
        'OptimizerFactory.build_optimizer returning a non-optimizer object: '
    assert isinstance(
        optimizer, (tf.keras.optimizers.Optimizer,
                    tf.keras.optimizers.experimental.Optimizer)
    ), ('OptimizerFactory.build_optimizer returning a non-optimizer object: '
        '{}'.format(optimizer))
    return optimizer
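End to end, a hedged sketch of driving the new path through the factory; `OptimizationConfig`, `build_learning_rate`, and the constant learning-rate schedule are assumed from the surrounding module, not shown in this diff:

    from official.modeling.optimization import optimizer_factory
    from official.modeling.optimization.configs import optimization_config

    params = {
        'optimizer': {
            'type': 'sgd_experimental',
            'sgd_experimental': {'momentum': 0.9, 'nesterov': True},
        },
        'learning_rate': {
            'type': 'constant',
            'constant': {'learning_rate': 0.01},
        },
    }
    opt_config = optimization_config.OptimizationConfig(params)
    factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = factory.build_learning_rate()
    # With the relaxed assert above, build_optimizer may now return an instance
    # of tf.keras.optimizers.experimental.Optimizer as well.
    optimizer = factory.build_optimizer(lr)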