# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Custom training loop with Keras models

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential models
  without a pre-defined input shape, or 2) subclassed models. Pass var_list as a
  callable in these cases.

  Example:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `tf.GradientTape`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  var_list = <list_of_variables>
  grads = tape.gradient(loss, var_list)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, var_list))
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  To aggregate gradients yourself, call `apply_gradients` with
  `all_reduce_sum_gradients` set to False. This is useful if you need to process
  aggregated gradients.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.
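
  For example, a sketch of manual loss scaling inside a custom training step
  (`model`, `features`, `labels` and `global_batch_size` are assumed to be
  provided by the surrounding training code):

  ```python
  # Sum the per-example losses on this replica, then divide by the *global*
  # batch size so that the SUM aggregation across replicas yields an average.
  per_example_loss = tf.keras.losses.mse(labels, model(features))
  loss = tf.math.reduce_sum(per_example_loss) / global_batch_size
  ```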

  ### Variable Constraint

  All Keras optimizers respect variable constraints. If a constraint function is
  passed to any variable, the constraint will be applied to the variable after
  the gradient has been applied.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.
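
  For example, a minimal sketch of a constrained (dense) variable:

  ```python
  # The constraint is re-applied to the variable after every gradient update.
  var = tf.Variable(2.0, constraint=lambda x: tf.clip_by_value(x, 0., 1.))
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  opt.minimize(lambda: (var - 3.) ** 2, var_list=[var])
  # `var` now holds the updated value clipped back into [0, 1].
  ```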

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The user
  needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train.  These are called
  *Slots*.  Slots have names and you can ask the optimizer for the names of
  the slots that it uses.  Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
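
  For example, a sketch of inspecting slots (slot names vary per optimizer; SGD
  with momentum keeps a `momentum` slot per trainable variable):

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  var = tf.Variable(1.0)
  opt.minimize(lambda: var ** 2, var_list=[var])
  print(opt.get_slot_names())          # e.g. ['momentum']
  momentum_var = opt.get_slot(var, 'momentum')
  ```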

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a minimal sketch follows the
  list):

    - _resource_apply_dense (update variable given a dense gradient tensor)
    - _resource_apply_sparse (update variable given a sparse gradient tensor)
    - _create_slots (if your optimizer algorithm requires additional variables)
    - get_config (serialization of the optimizer; include all hyper parameters)
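
  For example, a minimal sketch of a plain gradient-descent optimizer (sparse
  handling and error checking trimmed for brevity; the class name is
  hypothetical):

  ```python
  class MyOptimizer(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, name="MyOptimizer", **kwargs):
      super(MyOptimizer, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def _resource_apply_dense(self, grad, var, apply_state=None):
      # Plain gradient descent: var <- var - lr * grad.
      lr_t = self._get_hyper("learning_rate", var.dtype.base_dtype)
      return var.assign_sub(lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
      # Scatter the scaled gradient into the variable at `indices`.
      lr_t = self._get_hyper("learning_rate", var.dtype.base_dtype)
      return self._resource_scatter_add(var, indices, -lr_t * grad)

    def get_config(self):
      config = super(MyOptimizer, self).get_config()
      config.update(
          {"learning_rate": self._serialize_hyperparameter("learning_rate")})
      return config
  ```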
  """

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for backward
        compatibility; `learning_rate` is recommended instead.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # checks that all keyword arguments are non-negative.
      if kwargs[k] is not None and kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay

    # Set the gradient clipping properties
    self.clipnorm = kwargs.pop("clipnorm", None)
    self.clipvalue = kwargs.pop("clipvalue", None)
    if ((self.clipnorm is not None or self.clipvalue is not None)
        and distribute_ctx.has_strategy()):
      raise ValueError("Gradient clipping in the optimizer "
                       "(by setting clipnorm or clipvalue) is currently "
                       "unsupported when using a distribution strategy.")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes the gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.
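
    Example (a minimal eager-mode sketch):

    ```python
    var = tf.Variable(2.0)
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    loss = lambda: (var ** 2) / 2.0   # d(loss)/d(var) == var
    opt.minimize(loss, var_list=[var])
    # var is now 2.0 - 0.1 * 2.0 = 1.8 (in eager mode).
    ```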

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize`, since the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _clip_gradients(self, grads):
    """Clip gradients according to the clipnorm and clipvalue attributes."""
    if self.clipnorm is not None:
      if distribute_ctx.has_strategy():
        raise ValueError("Gradient clipping in the optimizer "
                         "(by setting clipnorm or clipvalue) is currently "
                         "unsupported when using a distribution strategy.")
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if self.clipvalue is not None:
      if distribute_ctx.has_strategy():
        raise ValueError("Gradient clipping in the optimizer "
                         "(by setting clipnorm or clipvalue) is currently "
                         "unsupported when using a distribution strategy.")
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize`, since the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      if not callable(var_list):
        tape.watch(var_list)
      loss_value = loss()
    if callable(var_list):
      var_list = var_list()
    var_list = nest.flatten(var_list)
    with backend.name_scope(self._name + "/gradients"):
      grads = tape.gradient(loss_value, var_list, grad_loss)
      grads = self._clip_gradients(grads)

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
      grads = self._clip_gradients(grads)
    return grads

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      all_reduce_sum_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `all_reduce_sum_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.
      all_reduce_sum_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it is
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()
      apply_state = self._prepare(var_list)
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
              "all_reduce_sum_gradients": all_reduce_sum_gradients,
          })

  def _aggregate_gradients(self, distribution, grads_and_vars):
    """Returns all-reduced gradients."""
    return distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state,
                         all_reduce_sum_gradients):
    """`apply_gradients` using a `DistributionStrategy`."""
    if all_reduce_sum_gradients:
      reduced_grads = self._aggregate_gradients(distribution, grads_and_vars)
      var_list = [v for _, v in grads_and_vars]
      grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with ops.name_scope(name or self._name, skip_on_eager=True):
      for grad, var in grads_and_vars:
        # TODO(crccw): It's not allowed to assign PerReplica value to
        # MirroredVariable.  Remove this after we relax this restriction.
        def _assume_mirrored(grad):
          if isinstance(grad, ds_values.PerReplica):
            return ds_values.Mirrored(grad.values)
          return grad

        grad = nest.map_structure(_assume_mirrored, grad)
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with ops.name_scope("update" if eagerly_outside_functions else
                              "update_" + var.op.name, skip_on_eager=True):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      if not strategy.extended.variable_created_in_scope(var):
        raise ValueError(
            "Trying to create optimizer slot variable under the scope for "
            "tf.distribute.Strategy ({}), which is different from the scope "
            "used for the original variable ({}). Make sure the slot "
            "variables are created under the same strategy scope. This may "
            "happen if you're restoring from a checkpoint outside the scope"
            .format(strategy, var))

      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      var_devices = (getattr(var, "devices", None) or  # Distributed
                     [var.device])                     # Regular
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(
          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    config = {"name": self._name}
    if self.clipnorm is not None:
      config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.
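
    Example (a sketch round-tripping an `SGD` config):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    config = opt.get_config()
    new_opt = tf.keras.optimizers.SGD.from_config(config)
    ```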

    Arguments:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional Python
          objects used to create this optimizer, such as a function used for a
          hyperparameter.

    Returns:
        An optimizer instance.
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tensor(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
        Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Arguments:
        weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a  weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set([
        dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
        dtypes.complex64, dtypes.complex128
    ])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
      return x.value()

  @property
  @tracking.cached_per_instance
  def _dense_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_dense).args

  @property
  @tracking.cached_per_instance
  def _sparse_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)


def _filter_grads(grads_and_vars):
  """Filter out iterable with grad equal to None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the variable's shared name.
  In eager mode the name is derived from the variable's unique id.
  If a distribution strategy exists, the distributed container variable is
  looked up first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])