Commit f9e99f4d authored by Reed Wanderman-Milne, committed by TensorFlower Gardener

Change LossScaleOptimizer checkpoint format.

The format is now identical to the format used when a LossScaleOptimizer is not present, except that the loss scale is additionally saved when a LossScaleOptimizer is used. This allows a checkpoint saved with a LossScaleOptimizer to be restored without one, and vice versa.

Checkpoints created with LossScaleOptimizers in older versions of TensorFlow can still be loaded. Newly saved checkpoints will use the new format.
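As an illustrative sketch of the behavior this enables (not part of this change; the checkpoint path below is hypothetical), a checkpoint written with a LossScaleOptimizer can be restored into a plain optimizer, and the other direction works the same way:

import tensorflow as tf

var = tf.Variable(1.0)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1), 'dynamic')
# Save with a LossScaleOptimizer wrapping SGD.
ckpt = tf.train.Checkpoint(optimizer=opt, var=var)
save_path = ckpt.save('/tmp/lso_ckpt')

# Restore without a LossScaleOptimizer. The optimizer state lines up with the
# plain SGD; the loss-scale entries in the checkpoint simply go unmatched.
plain_opt = tf.keras.optimizers.SGD(0.1)
tf.train.Checkpoint(optimizer=plain_opt, var=var).restore(save_path)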

PiperOrigin-RevId: 306511555
Change-Id: Ie316ab8c4fbfec7babd6f7803d337799d0ff10a5
Parent 4e809823
......@@ -224,9 +224,16 @@ cuda_py_test(
name = "keras_test",
size = "medium",
srcs = ["keras_test.py"],
data = [
"//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2",
"//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2",
],
python_version = "PY3",
shard_count = 10,
tags = ["no_windows"], # b/139083295: bfloat16 tests fail on Windows
tags = [
"no_pip",
"no_windows", # b/139083295: bfloat16 tests fail on Windows
],
deps = [
":test_util",
"//tensorflow/python:client_testlib",
......
......@@ -41,6 +41,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
......@@ -993,6 +994,56 @@ class KerasModelTest(keras_parameterized.TestCase):
self.assertEqual(backend.get_value(loss_scale()), 2)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
@keras_parameterized.run_all_keras_modes
def test_restore_old_loss_scale_checkpoint(self):
# Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
# of LossScaleOptimizer changed, but old checkpoints can still be loaded
opt = gradient_descent.SGD(0.1, momentum=0.1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
model = sequential.Sequential([core.Dense(2,)])
# The checkpoint and expected values were obtained from the program in
# testdata/BUILD.
ckpt_dir = test.test_src_dir_path(
'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2')
model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
model(np.zeros((2, 2))) # Create model weights
opt._create_all_weights(model.weights)
expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
self.assertAllClose(
self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
expected_slot)
self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
# Check restoring works even after the model is compiled and the weights
# have been created.
model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
self.assertNotAllClose(
self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
expected_slot)
model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
self.assertAllClose(
self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
expected_slot)
self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
def test_restore_old_saved_model(self):
saved_model_dir = test.test_src_dir_path(
'python/keras/mixed_precision/experimental/testdata/'
'lso_savedmodel_tf2.2')
model = save.load_model(saved_model_dir)
expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
self.assertIsInstance(model.optimizer,
loss_scale_optimizer.LossScaleOptimizer)
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
{
......
......@@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend
......@@ -32,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
......@@ -51,8 +53,126 @@ class _UnwrapPreventer(object):
self.value = value
class _DelegatingTrackableMixin(object):
"""A mixin that delegates all Trackable methods to another trackable object.
This class must be used with multiple inheritance. A class that subclasses
Trackable can also subclass this class, which causes all Trackable methods to
be delegated to the trackable object passed in the constructor.
A subclass can use this mixin to appear as if it were the trackable passed to
the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
the checkpoint format for a normal optimizer. This allows a model to be saved
with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
The only difference in checkpoint format is that the loss scale is also saved
with a LossScaleOptimizer.
"""
def __init__(self, trackable_obj):
self._trackable = trackable_obj
# pylint: disable=protected-access
@property
def _setattr_tracking(self):
return self._trackable._setattr_tracking
@_setattr_tracking.setter
def _setattr_tracking(self, value):
self._trackable._setattr_tracking = value
@property
def _update_uid(self):
return self._trackable._update_uid
@_update_uid.setter
def _update_uid(self, value):
self._trackable._update_uid = value
@property
def _unconditional_checkpoint_dependencies(self):
return self._trackable._unconditional_checkpoint_dependencies
@property
def _unconditional_dependency_names(self):
return self._trackable._unconditional_dependency_names
@property
def _name_based_restores(self):
return self._trackable._name_based_restores
def _maybe_initialize_trackable(self):
return self._trackable._maybe_initialize_trackable()
@property
def _object_identifier(self):
return self._trackable._object_identifier
@property
def _tracking_metadata(self):
return self._trackable._tracking_metadata
def _no_dependency(self, value):
return self._trackable._no_dependency(value)
def _name_based_attribute_restore(self, checkpoint):
return self._trackable._name_based_attribute_restore(checkpoint)
@property
def _checkpoint_dependencies(self):
return self._trackable._checkpoint_dependencies
@property
def _deferred_dependencies(self):
return self._trackable._deferred_dependencies
def _lookup_dependency(self, name):
return self._trackable._lookup_dependency(name)
def _add_variable_with_custom_getter(self,
name,
shape=None,
dtype=dtypes.float32,
initializer=None,
getter=None,
overwrite=False,
**kwargs_for_getter):
return self._trackable._add_variable_with_custom_getter(
name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
def _preload_simple_restoration(self, name, shape):
return self._trackable._preload_simple_restoration(name, shape)
def _track_trackable(self, trackable, name, overwrite=False): # pylint: disable=redefined-outer-name
return self._trackable._track_trackable(trackable, name, overwrite)
def _handle_deferred_dependencies(self, name, trackable): # pylint: disable=redefined-outer-name
return self._trackable._handle_deferred_dependencies(name, trackable)
def _restore_from_checkpoint_position(self, checkpoint_position):
return self._trackable._restore_from_checkpoint_position(
checkpoint_position)
def _single_restoration_from_checkpoint_position(self, checkpoint_position,
visit_queue):
return self._trackable._single_restoration_from_checkpoint_position(
checkpoint_position, visit_queue)
def _gather_saveables_for_checkpoint(self):
return self._trackable._gather_saveables_for_checkpoint()
def _list_extra_dependencies_for_serialization(self, serialization_cache):
return self._trackable._list_extra_dependencies_for_serialization(
serialization_cache)
def _list_functions_for_serialization(self, serialization_cache):
return self._trackable._list_functions_for_serialization(
serialization_cache)
# pylint: enable=protected-access
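class _ExampleWrapper(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
  """Illustrative sketch only (not part of this change).

  Shows the intended usage of the mixin: list it first in the MRO and
  initialize it with the wrapped object, so that a Checkpoint sees the wrapper
  as if it were that object. LossScaleOptimizer below does exactly this; the
  class name here is hypothetical.
  """

  def __init__(self, inner_optimizer):
    self._optimizer = inner_optimizer
    # Skip OptimizerV2.__init__; delegate all Trackable machinery to the
    # wrapped optimizer instead.
    _DelegatingTrackableMixin.__init__(self, inner_optimizer)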
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizer(optimizer_v2.OptimizerV2):
class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
"""An optimizer that applies loss scaling.
Loss scaling is a process that multiplies the loss by a multiplier called the
......@@ -144,6 +264,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if self._loss_scale is None:
raise ValueError('loss_scale cannot be None.')
# We don't call super().__init__, since we do not want to call OptimizerV2's
# constructor.
_DelegatingTrackableMixin.__init__(self, self._optimizer)
for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
# We cannot call `track_variable` in the LossScale class itself, because a
# file outside of Keras cannot depend on a Keras file. Calling it here
......@@ -151,12 +276,15 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
# a Keras class, and the only way to use LossScale with a Keras class is
# through the LossScaleOptimizer.
backend.track_variable(weight)
self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._loss_scale, 'loss_scale')
# Needed because the superclass's __getattribute__ checks this.
self._hyper = {}
# To support restoring TensorFlow 2.2 checkpoints.
self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
'base_optimizer')
@property
def loss_scale(self):
"""The `LossScale` instance associated with this optimizer."""
......@@ -348,6 +476,21 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
def _aggregate_gradients(self, grads_and_vars):
return self._optimizer._aggregate_gradients(grads_and_vars) # pylint: disable=protected-access
def _restore_slot_variable(self, slot_name, variable, slot_variable):
return self._optimizer._restore_slot_variable(slot_name, variable, # pylint: disable=protected-access
slot_variable)
def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
variable):
return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access
slot_variable_position, slot_name, variable)
def get_slot(self, var, slot_name):
return self._optimizer.get_slot(var, slot_name)
def add_slot(self, var, slot_name, initializer='zeros'):
return self._optimizer.add_slot(var, slot_name, initializer)
# For the most part, we only expose methods in the base OptimizerV2, not
# individual subclasses like Adam. However, although "learning_rate" and "lr"
# properties are not part of the base OptimizerV2 class, they are part of most
......@@ -369,23 +512,6 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
def lr(self, lr):
self._optimizer.lr = lr
def get_slot(self, var, slot_name):
# We cannot implement get_slot for the following reason: When saving a
# checkpoint, two optimizers cannot share slot variables. Since both the
# LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
# respectively) are checkpointed, we cannot expose the wrapped optimizer's
# slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
# both optimizers share slot variables.
raise AttributeError(
'You cannot call get_slot on a LossScaleOptimizer. This limitation '
'will be removed in the future.')
def add_slot(self, var, slot_name, initializer='zeros'):
# We disallow adding a slot for consistency with `get_slot`.
raise AttributeError(
'You cannot call add_slot on a LossScaleOptimizer. This limitation '
'will be removed in the future.')
# We do not override some OptimizerV2 methods. For each, we describe why we do
# not delegate them to self._optimizer:
# * get_updates: get_updates() calls get_gradients(). Since we override
......@@ -402,6 +528,51 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
# TODO(reedwm): Maybe throw an error if mixed precision is used without this
# optimizer being used.
# Trackable delegations: Delegate all Trackable methods to the wrapped
# optimizer. This is so the checkpoint format for a LossScaleOptimizer is
# identical to the checkpoint format for a normal optimizer, except the loss
# scale is stored in the checkpoint.
class FakeOptimizerForRestoration(trackable.Trackable):
"""A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.
The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
exists to support restoring TF 2.2 checkpoints in newer versions of TensorFlow.
In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
following in LossScaleOptimizer.__init__
```
self._track_trackable(self._optimizer, 'base_optimizer')
```
This means a dependency from the LossScaleOptimizer to the wrapped optimizer
would be stored in the checkpoint. However now, the checkpoint format with a
LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
except the loss scale is also stored. This means there is no dependency from
the LossScaleOptimizer to the wrapped optimizer. Instead, the
LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
perspective, by overriding all Trackable methods and delegating them to the
wrapped optimizer.
To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
on this class instead of the inner optimizer. When restored, this class will
instead restore the slot variables of the inner optimizer. Since this class
has no variables, it does not affect the checkpoint when saved.
"""
def __init__(self, optimizer):
self._optimizer = optimizer
def get_slot_names(self):
return self._optimizer.get_slot_names()
def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
variable):
return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access
slot_variable_position, slot_name, variable)
# pylint: disable=protected-access
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
......
......@@ -305,20 +305,6 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
opt.set_weights([np.array(2.)])
self.assertEqual(self.evaluate(opt.variables()[0]), 2)
def testSlotMethodErrors(self):
opt = gradient_descent.SGD(1.0, momentum=1.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
with self.assertRaisesRegexp(
AttributeError,
'You cannot call get_slot on a LossScaleOptimizer. This limitation '
'will be removed in the future.'):
opt.get_slot(None, None)
with self.assertRaisesRegexp(
AttributeError,
'You cannot call add_slot on a LossScaleOptimizer. This limitation '
'will be removed in the future.'):
opt.add_slot(None, None)
def testPassingNoneToLossScale(self):
opt = gradient_descent.SGD()
with self.assertRaisesRegexp(ValueError, r'loss_scale cannot be None'):
......@@ -394,9 +380,49 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
run_fn = lambda: opt.minimize(loss, [var])
strategy.experimental_run(run_fn)
@parameterized.named_parameters(*TESTCASES)
def testCheckpoint(self, strategy_fn):
@parameterized.named_parameters({
'testcase_name': 'SaveAndRestoreBase',
'strategy_fn': default_strategy_fn,
'save_with_ls': True,
'restore_with_ls': True,
}, {
'testcase_name': 'SaveAndRestoreDistribute',
'strategy_fn': create_mirrored_strategy,
'save_with_ls': True,
'restore_with_ls': True,
}, {
'testcase_name': 'SaveBase',
'strategy_fn': default_strategy_fn,
'save_with_ls': True,
'restore_with_ls': False,
}, {
'testcase_name': 'SaveDistribute',
'strategy_fn': create_mirrored_strategy,
'save_with_ls': True,
'restore_with_ls': False,
}, {
'testcase_name': 'RestoreBase',
'strategy_fn': default_strategy_fn,
'save_with_ls': False,
'restore_with_ls': True,
}, {
'testcase_name': 'RestoreDistribute',
'strategy_fn': create_mirrored_strategy,
'save_with_ls': False,
'restore_with_ls': True,
})
def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
class MySGD(gradient_descent.SGD):
"""A custom optimizer that tracks an extra variable."""
def __init__(self, *args, **kwargs):
super(MySGD, self).__init__(*args, **kwargs)
self.my_var = variables.Variable(0.)
self._track_trackable(self.my_var, 'my_var')
strategy = strategy_fn()
replicas = strategy.num_replicas_in_sync
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
not context.executing_eagerly()):
# TODO(b/121381184): Enable running the test in this case.
......@@ -405,38 +431,89 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with self.test_session(), strategy.scope():
# Build and run a simple model.
var = variables.Variable([2.0])
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2.,
multiplier=2.)
opt = gradient_descent.SGD(1., momentum=1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
opt = inner_opt = MySGD(1., momentum=1.)
if save_with_ls:
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2.,
multiplier=2.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
opt_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self.evaluate(opt_op)
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
slot_var = opt._optimizer.get_slot(var, 'momentum')
slot_value = self.evaluate(slot_var).item()
self.evaluate(strategy.experimental_local_results(opt_op))
# Assert values.
self.assertEqual(self.evaluate(var), 1.)
if save_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
slot_var = opt.get_slot(var, 'momentum')
self.assertEqual(self.evaluate(slot_var).item(), -1)
self.assertEqual(self.evaluate(opt.iterations), 1)
# Set optimizer variable to check arbitrary optimizer attributes can be
# saved/restored
self.evaluate(inner_opt.my_var.assign(1.))
# Save a checkpoint.
checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
prefix = os.path.join(self.get_temp_dir(), 'ckpt')
save_path = checkpoint.save(prefix)
# Run model again.
self.evaluate(strategy.experimental_run(run_fn))
self.assertEqual(self.evaluate(loss_scale()), 2.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
self.assertNotAlmostEqual(self.evaluate(slot_var).item(), slot_value)
# Create new model
var = variables.Variable([2.0])
opt = inner_opt = MySGD(1., momentum=1.)
if restore_with_ls:
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2.,
multiplier=2.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
# Restore new model.
checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
status = checkpoint.restore(save_path)
if save_with_ls:
status.assert_existing_objects_matched()
else:
status.assert_nontrivial_match()
# Assert restored values. We can only assert in eager mode since the
# variables are uninitialized in graph mode
if context.executing_eagerly():
self.assertEqual(self.evaluate(var), 1.)
if save_with_ls and restore_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
elif restore_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
self.assertEqual(self.evaluate(opt.iterations), 1)
# Run the model again.
run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
opt_op = strategy.experimental_run(run_fn)
# Load checkpoint and ensure the loss scale is back to its original value.
# Assert new values.
self.evaluate(variables.global_variables_initializer())
status.run_restore_ops()
self.evaluate(strategy.experimental_local_results(opt_op))
self.assertEqual(self.evaluate(var), -1)
slot_var = opt.get_slot(var, 'momentum')
self.assertEqual(self.evaluate(slot_var).item(), -2)
self.assertEqual(self.evaluate(opt.iterations), 2)
self.assertEqual(self.evaluate(inner_opt.my_var), 1)
# Restore model again to test restoring after slots are created
status = checkpoint.restore(save_path)
status.assert_consumed()
if save_with_ls and restore_with_ls:
status.assert_consumed()
elif save_with_ls:
status.assert_existing_objects_matched()
elif restore_with_ls:
status.assert_nontrivial_match()
status.run_restore_ops()
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
self.assertEqual(self.evaluate(var), 1)
self.assertEqual(self.evaluate(slot_var).item(), -1)
def testGetConfig(self):
opt = gradient_descent.SGD(2., momentum=0.5)
......
# Description:
# Contains checkpoints and SavedModels for testing purposes.
package(
default_visibility = [
"//tensorflow/python/keras:__subpackages__",
"//tensorflow/tools/pip_package:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
# These files were generated by running the following program with TensorFlow
# 2.2rc2. The final release of TF 2.2 was not out when this change was created:
# import os
# import numpy as np
# import tensorflow as tf
#
# tf.random.set_seed(1)
# opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
# opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
# model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
# model.compile(opt, 'mse')
#
# x = np.ones((10, 2))
# y = x * 100
# model.fit(x, y)
# weight_dir = os.environ['TF_LSO_WEIGHT_DIR']
# model_dir = os.environ['TF_LSO_MODEL_DIR']
# model.save_weights(weight_dir)
# model.save(model_dir)
# print(model.get_weights()[0])
# print(opt._optimizer.get_slot(model.weights[0], 'momentum'))
# print(opt.loss_scale)
filegroup(
name = "lso_ckpt_tf2.2",
srcs = glob(["lso_ckpt_tf2.2/**"]),
tags = ["no_pip"],
)
filegroup(
name = "lso_savedmodel_tf2.2",
srcs = glob(["lso_savedmodel_tf2.2/**"]),
tags = ["no_pip"],
)
model_checkpoint_path: "ckpt"
all_model_checkpoint_paths: "ckpt"
(Binary content of the new lso_ckpt_tf2.2 checkpoint index/data files is not rendered here. The recoverable variable keys include layer_with_weights-0/kernel and layer_with_weights-0/bias, optimizer/base_optimizer/{iter, decay, learning_rate, momentum}, optimizer/loss_scale/{current_loss_scale, good_steps}, and momentum slot variables stored under layer_with_weights-0/*/.OPTIMIZER_SLOT/optimizer/base_optimizer/momentum.)
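These keys follow the old TF 2.2 layout, in which the wrapped optimizer's hyperparameters and slots are nested under optimizer/base_optimizer/ and the loss scale under optimizer/loss_scale/; with the new format the base_optimizer level no longer appears in newly written checkpoints. A quick way to inspect the keys of a checkpoint (a hedged sketch; the path is hypothetical):

import tensorflow as tf

for name, shape in tf.train.list_variables('lso_ckpt_tf2.2/ckpt'):
  print(name, shape)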
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
......
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
......