# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
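
    For example, `values=[[1., 2.], [3., 4.], [5., 6.]]` with
    `indices=[0, 2, 0]` yields `summed_values=[[6., 8.], [3., 4.]]` and
    `unique_indices=[0, 2]`.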
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Custom training loop with Keras models

  In Keras models, sometimes variables are created when the model is first
  called, instead of at construction time. Examples include 1) sequential models
  without an input shape pre-defined, or 2) subclassed models. Pass `var_list`
  as a callable in these cases.

  Example:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `tf.GradientTape`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, vars))
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for not averaging.

  To aggregate gradients yourself, call `apply_gradients` with
  `experimental_aggregate_gradients` set to False. This is useful if you need to
  process aggregated gradients.

  If you are not using these built-in loops and you want to average gradients,
  you should use `tf.math.reduce_sum` to add up your per-example losses and
  then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.
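
  For example (an illustrative sketch; `compute_per_example_loss` and
  `global_batch_size` are placeholders for whatever your loss function and
  input pipeline provide):

  ```python
  per_example_loss = compute_per_example_loss(labels, predictions)
  loss = tf.math.reduce_sum(per_example_loss) * (1. / global_batch_size)
  ```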

  ### Variable Constraint

  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to the variable.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The user
  needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train.  These are called
  <i>Slots</i>.  Slots have names and you can ask the optimizer for the names of
  the slots that it uses.  Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to log debug information about a training
  algorithm, report stats about the slots, etc.
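
  For example (a sketch reusing `loss` and `var1` from the usage example above;
  slot names differ between optimizers, and `Adam` uses `'m'` and `'v'` here):

  ```python
  opt = tf.keras.optimizers.Adam()
  opt.minimize(loss, var_list=[var1])
  opt.get_slot_names()     # ['m', 'v']
  opt.get_slot(var1, 'm')  # The first-moment slot variable created for `var1`.
  ```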

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods:

    - resource_apply_dense (update variable given gradient tensor is dense)
    - resource_apply_sparse (update variable given gradient tensor is sparse)
    - create_slots (if your optimizer algorithm requires additional variables)
    - get_config (serialization of the optimizer, includes all hyper parameters)
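
  Example (a minimal, dense-only sketch; `MySGDWithMomentum` is a hypothetical
  subclass written against the methods listed above, not an optimizer shipped
  with Keras):

  ```python
  class MySGDWithMomentum(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, momentum=0.9,
                 name="MySGDWithMomentum", **kwargs):
      super(MySGDWithMomentum, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
      self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
      # One extra "momentum" variable per trainable variable.
      for var in var_list:
        self.add_slot(var, "momentum")

    def _resource_apply_dense(self, grad, var, apply_state=None):
      lr_t = self._get_hyper("learning_rate", var.dtype.base_dtype)
      momentum_t = self._get_hyper("momentum", var.dtype.base_dtype)
      m = self.get_slot(var, "momentum")
      m_t = m.assign(momentum_t * m + grad, use_locking=self._use_locking)
      return var.assign_sub(lr_t * m_t, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
      raise NotImplementedError("Sparse gradients are not handled here.")

    def get_config(self):
      config = super(MySGDWithMomentum, self).get_config()
      config.update({
          "learning_rate": self._serialize_hyperparameter("learning_rate"),
          "momentum": self._serialize_hyperparameter("momentum"),
      })
      return config
  ```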
  """

  # Subclasses should set this to True unless they override `apply_gradients`
  # with a version that does not have the `experimental_aggregate_gradients`
  # argument.  Older versions of Keras did not have this argument so custom
  # optimizers may have overridden `apply_gradients` without the
  # `experimental_aggregate_gradients` argument. Keras only passes
  # `experimental_aggregate_gradients` if this attribute is True.
  # Note: This attribute will likely be removed in an upcoming release.
  _HAS_AGGREGATE_GRAD = False

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for backward
        compatibility; `learning_rate` is recommended instead.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # Checks that all keyword argument values are non-negative.
      if kwargs[k] is not None and kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay

    # Set the gradient clipping properties
    self.clipnorm = kwargs.pop("clipnorm", None)
    self.clipvalue = kwargs.pop("clipvalue", None)
    if ((self.clipnorm is not None or self.clipvalue is not None)
        and distribute_ctx.has_strategy()):
      raise ValueError("Gradient clipping in the optimizer "
                       "(by setting clipnorm or clipvalue) is currently "
                       "unsupported when using a distribution strategy.")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize`, since the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _clip_gradients(self, grads):
    """Clip gradients according to the clipnorm and clipvalue attributes."""
    if self.clipnorm is not None:
      if distribute_ctx.has_strategy():
        raise ValueError("Gradient clipping in the optimizer "
                         "(by setting clipnorm or clipvalue) is currently "
                         "unsupported when using a distribution strategy.")
      grads = [None if g is None else clip_ops.clip_by_norm(g, self.clipnorm)
               for g in grads]
    if self.clipvalue is not None:
      if distribute_ctx.has_strategy():
        raise ValueError("Gradient clipping in the optimizer "
                         "(by setting clipnorm or clipvalue) is currently "
                         "unsupported when using a distribution strategy.")
      v = self.clipvalue
      grads = [
          None if g is None else clip_ops.clip_by_value(g, -v, v) for g in grads
      ]
    return grads

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize` and the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      if not callable(var_list):
        tape.watch(var_list)
      loss_value = loss()
    if callable(var_list):
      var_list = var_list()
    var_list = nest.flatten(var_list)
    with backend.name_scope(self._name + "/gradients"):
      grads = tape.gradient(loss_value, var_list, grad_loss)
      grads = self._clip_gradients(grads)

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
      grads = self._clip_gradients(grads)
    return grads

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it is
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients()` cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.experimental_run_v2` to enter replica "
            "context.")

      strategy = distribute_ctx.get_strategy()
      if (not experimental_aggregate_gradients and strategy and isinstance(
          strategy.extended,
          parameter_server_strategy.ParameterServerStrategyExtended)):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False` is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        reduced_grads = self._aggregate_gradients(grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = list(zip(reduced_grads, var_list))
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
          })

  def _aggregate_gradients(self, grads_and_vars):
    """Returns all-reduced gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.

    Returns:
      A list of all-reduced gradients.
    """
    grads_and_vars = list(grads_and_vars)
    filtered_grads_and_vars = _filter_grads(grads_and_vars)
    def all_reduce_fn(distribution, grads_and_vars):
      return distribution.extended.batch_reduce_to(
          ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    # We switch to a cross-replica context since there is a bug which causes
    # IndexedSlices to be converted to dense tensors when all-reduced in a
    # replica context.
    # TODO(b/150507409): Do not switch to a cross-replica context once the bug
    # is fixed.
    if filtered_grads_and_vars:
      reduced = distribute_ctx.get_replica_context().merge_call(
          all_reduce_fn, args=(filtered_grads_and_vars,))
    else:
      reduced = []
    # Copy 'reduced' but add None gradients back in
    reduced_with_nones = []
    reduced_pos = 0
    for g, _ in grads_and_vars:
      if g is None:
        reduced_with_nones.append(None)
      else:
        reduced_with_nones.append(reduced[reduced_pos])
        reduced_pos += 1
    assert reduced_pos == len(reduced), "Failed to add all gradients"
    return reduced_with_nones

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with ops.name_scope(name or self._name, skip_on_eager=True):
      for grad, var in grads_and_vars:
        # TODO(crccw): It's not allowed to assign PerReplica value to
        # MirroredVariable.  Remove this after we relax this restriction.
        def _assume_mirrored(grad):
          if isinstance(grad, ds_values.PerReplica):
            return ds_values.Mirrored(grad.values)
          return grad

        grad = nest.map_structure(_assume_mirrored, grad)
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with ops.name_scope("update" if eagerly_outside_functions else
                              "update_" + var.op.name, skip_on_eager=True):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """Set hyper `name` to value. Value can be callable, tensor, or numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      if not strategy.extended.variable_created_in_scope(var):
        raise ValueError(
            "Trying to create optimizer slot variable under the scope for "
            "tf.distribute.Strategy ({}), which is different from the scope "
            "used for the original variable ({}). Make sure the slot "
            "variables are created under the same strategy scope. This may "
            "happen if you're restoring from a checkpoint outside the scope"
            .format(strategy, var))

      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    """Returns the slot variable named `slot_name` created for `var`."""
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      if isinstance(var, ds_values.DistributedValues):
        var_devices = var._devices   # pylint: disable=protected-access
      else:
        var_devices = [var.device]
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(
          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    config = {"name": self._name}
    if self.clipnorm is not None:
      config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Arguments:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional Python
          objects used to create this optimizer, such as a function used for a
          hyperparameter.

    Returns:
        An optimizer instance.
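
    Example (an illustrative round trip; any built-in optimizer such as
    `tf.keras.optimizers.SGD` behaves the same way):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    config = opt.get_config()
    opt2 = tf.keras.optimizers.SGD.from_config(config)
    ```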
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tensor(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
        Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Arguments:
        weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a  weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set([
        dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
        dtypes.complex64, dtypes.complex128
    ])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
      return x.value()

  @property
  @tracking.cached_per_instance
  def _dense_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_dense).args

  @property
  @tracking.cached_per_instance
  def _sparse_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
1142 1143
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)


def _filter_grads(grads_and_vars):
  """Filter out (grad, var) pairs whose gradient is None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])