Commit 24593b1c authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Adds `get_config` and `from_config` to Optimizers V2.

PiperOrigin-RevId: 216546565
Parent f0225119
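In short, this change lets a V2 optimizer be serialized to a plain Python dict and rebuilt from it. A minimal sketch of the intended round trip, using the import path from the test files in this diff (the config keys shown in the comment are illustrative, not exhaustive):

from tensorflow.python.keras.optimizer_v2 import adam

opt = adam.Adam(learning_rate=1.0, beta_1=2.0, beta_2=3.0, epsilon=4.0)
config = opt.get_config()             # plain dict: optimizer name plus its hyperparameters
opt2 = adam.Adam.from_config(config)  # equivalent optimizer, without any saved slot state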
......@@ -143,10 +143,12 @@ class CheckpointingTests(test.TestCase):
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
# The Dense layers also save get_config() JSON
expected_checkpoint_names.extend(
["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
# The optimizer and Dense layers also save get_config() JSON
expected_checkpoint_names.extend([
"optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
])
named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
......
......@@ -37,6 +37,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
      Tensor or a Python value.

  Arguments:
    learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
      to leave it at the default value.
    rho: float hyperparameter >= 0. The decay rate.
......@@ -114,3 +115,12 @@ class Adadelta(optimizer_v2.OptimizerV2):
        grad,
        indices,
        use_locking=self._use_locking)

  def get_config(self):
    config = super(Adadelta, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "rho": self._serialize_hyperparameter("rho"),
        "epsilon": self._serialize_hyperparameter("epsilon")
    })
    return config
......@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import adadelta
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
......@@ -161,6 +162,22 @@ class AdadeltaOptimizerTest(test.TestCase):
            self.assertAllCloseAccordingToType(
                [[-111, -138]], var0.eval())

  def testConfig(self):
    def rho():
      return ops.convert_to_tensor(1.0)
    epsilon = ops.convert_to_tensor(1.0)
    opt = adadelta.Adadelta(learning_rate=1.0, rho=rho, epsilon=epsilon)
    config = opt.get_config()
    opt2 = adadelta.Adadelta.from_config(config)
    self.assertEqual(opt._hyper["learning_rate"][1],
                     opt2._hyper["learning_rate"][1])
    self.assertEqual(opt._hyper["rho"][1].__name__,
                     opt2._hyper["rho"][1].__name__)
    self.assertEqual(opt._hyper["epsilon"][1], opt2._hyper["epsilon"][1])


if __name__ == "__main__":
  test.main()
......@@ -117,3 +117,11 @@ class Adagrad(optimizer_v2.OptimizerV2):
        grad,
        indices,
        use_locking=self._use_locking)

  def get_config(self):
    config = super(Adagrad, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "initial_accumulator_value": self._initial_accumulator_value
    })
    return config
......@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import types as python_types
import numpy as np
from tensorflow.python.framework import constant_op
......@@ -271,6 +273,17 @@ class AdagradOptimizerTest(test.TestCase):
    # Creating optimizer should cause no exception.
    adagrad.Adagrad(3.0, initial_accumulator_value=0.1)

  def testConfig(self):
    opt = adagrad.Adagrad(
        learning_rate=lambda: ops.convert_to_tensor(1.0),
        initial_accumulator_value=2.0)
    config = opt.get_config()
    opt2 = adagrad.Adagrad.from_config(config)
    self.assertIsInstance(opt2._hyper["learning_rate"][1],
                          python_types.LambdaType)
    self.assertEqual(opt._initial_accumulator_value,
                     opt2._initial_accumulator_value)


if __name__ == "__main__":
  test.main()
......@@ -201,3 +201,13 @@ class Adam(optimizer_v2.OptimizerV2):
    update_beta_2 = beta_2_power.assign(
        beta_2_power * state.get_hyper("beta_2"), use_locking=self._use_locking)
    return control_flow_ops.group(update_beta_1, update_beta_2)

  def get_config(self):
    config = super(Adam, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "beta_1": self._serialize_hyperparameter("beta_1"),
        "beta_2": self._serialize_hyperparameter("beta_2"),
        "epsilon": self._serialize_hyperparameter("epsilon")
    })
    return config
......@@ -329,5 +329,16 @@ class AdamOptimizerTest(test.TestCase):
      # for v1 and v2 respectively.
      self.assertEqual(6, len(set(opt.variables())))

  def testConfig(self):
    opt = adam.Adam(learning_rate=1.0, beta_1=2.0, beta_2=3.0, epsilon=4.0)
    config = opt.get_config()
    opt2 = adam.Adam.from_config(config)
    self.assertEqual(opt._hyper["learning_rate"][1],
                     opt2._hyper["learning_rate"][1])
    self.assertEqual(opt._hyper["beta_1"][1], opt2._hyper["beta_1"][1])
    self.assertEqual(opt._hyper["beta_2"][1], opt2._hyper["beta_2"][1])
    self.assertEqual(opt._hyper["epsilon"][1], opt2._hyper["epsilon"][1])


if __name__ == "__main__":
  test.main()
......@@ -143,10 +143,12 @@ class CheckpointingTests(test.TestCase):
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
# The Dense layers also save get_config() JSON
expected_checkpoint_names.extend(
["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
# The optimizer and Dense layers also save get_config() JSON
expected_checkpoint_names.extend([
"optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
"model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
])
named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
......
......@@ -1319,6 +1319,42 @@ class OptimizerV2(optimizer_v1.Optimizer):
        variable=variable,
        optional_op_name=self._name)

  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
      Python dictionary.
    """
    return {"name": self._name}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Arguments:
      config: A Python dictionary, typically the output of get_config.
      custom_objects: A Python dictionary mapping names to additional Python
        objects used to create this optimizer, such as a function used for a
        hyperparameter.

    Returns:
      An optimizer instance.
    """
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    return self._hyper[hyperparameter_name][1]

  # --------------
  # Unsupported parent methods
  # --------------
......
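Note that `_serialize_hyperparameter` returns the stored value as-is (float, callable, or Tensor) rather than evaluating it, so a callable hyperparameter comes back out of `get_config` as the callable itself. A hedged sketch of that behavior, mirroring the RMSProp testConfig further down (imports taken from this diff; not a definitive usage pattern):

from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import rmsprop

opt = rmsprop.RMSProp(learning_rate=1.0, rho=2.0,
                      momentum=lambda: ops.convert_to_tensor(3.0),
                      centered=True)
config = opt.get_config()
assert callable(config["momentum"])          # still the lambda, not an evaluated Tensor
opt2 = rmsprop.RMSProp.from_config(config)   # reconstructed with the same callable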
......@@ -237,3 +237,14 @@ class RMSProp(optimizer_v2.OptimizerV2):
        grad,
        indices,
        use_locking=self._use_locking)

  def get_config(self):
    config = super(RMSProp, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self._serialize_hyperparameter("epsilon"),
        "centered": self._centered
    })
    return config
......@@ -20,6 +20,7 @@ from __future__ import print_function
import copy
import math
import types as python_types
from absl.testing import parameterized
import numpy as np
......@@ -439,6 +440,27 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
                  (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
              ]), var1.eval())

  def testConfig(self):
    def momentum():
      return ops.convert_to_tensor(3.0)
    opt = rmsprop.RMSProp(
        learning_rate=1.0,
        rho=2.0,
        momentum=momentum,
        epsilon=lambda: ops.convert_to_tensor(4.0),
        centered=True)
    config = opt.get_config()
    opt2 = rmsprop.RMSProp.from_config(config)
    self.assertEqual(opt._hyper["learning_rate"][1],
                     opt2._hyper["learning_rate"][1])
    self.assertEqual(opt._hyper["rho"][1], opt2._hyper["rho"][1])
    self.assertEqual(opt._hyper["momentum"][1].__name__,
                     opt2._hyper["momentum"][1].__name__)
    self.assertIsInstance(opt2._hyper["epsilon"][1], python_types.LambdaType)
    self.assertEqual(True, opt2._centered)


if __name__ == "__main__":
  test.main()
......@@ -168,3 +168,17 @@ class SGD(optimizer_v2.OptimizerV2):
        grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
        grad.indices, grad.dense_shape)
    return var.scatter_sub(delta, use_locking=self._use_locking)

  def get_config(self):
    config = super(SGD, self).get_config()
    # `_use_momentum` controls whether momentum variables are created, so only
    # serialize the momentum hyperparameter when momentum is in use.
    if not self._use_momentum:
      momentum = None
    else:
      momentum = self._serialize_hyperparameter("momentum")
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "momentum": momentum,
        "nesterov": self._use_nesterov
    })
    return config
......@@ -754,6 +754,20 @@ class MomentumOptimizerTest(test.TestCase):
                  (0.9 * 0.01 + 0.01) * 2.0)
              ]), var1.eval())

  def testConfig(self):
    opt = sgd.SGD(learning_rate=1.0, momentum=2.0, nesterov=True)
    config = opt.get_config()
    opt2 = sgd.SGD.from_config(config)
    self.assertEqual(opt._hyper["learning_rate"][1],
                     opt2._hyper["learning_rate"][1])
    self.assertEqual(opt._hyper["momentum"][1], opt2._hyper["momentum"][1])
    self.assertEqual(opt2._use_nesterov, True)

    opt = sgd.SGD(momentum=None)
    config = opt.get_config()
    opt2 = sgd.SGD.from_config(config)
    self.assertEqual(False, opt2._use_momentum)


if __name__ == "__main__":
  test.main()