# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Custom training loop with Keras models

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential models
  without a pre-defined input shape, or 2) subclassed models. Pass `var_list` as
  a callable in these cases.

  Example:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `tf.GradientTape`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)
  processed_grads = [process_gradient(g) for g in grads]
  grads_and_vars = zip(processed_grads, vars)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.

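  For illustration, a minimal sketch of this manual scaling inside a custom
  training step (it assumes `model`, `optimizer` and `global_batch_size`
  already exist and that the function runs on each replica; it is a sketch,
  not the only way to do this):

  ```python
  def train_step(features, labels):
    with tf.GradientTape() as tape:
      per_example_loss = tf.keras.losses.mse(labels, model(features))
      # Sum per-example losses, then divide by the *global* batch size rather
      # than the replica-local one.
      loss = tf.math.reduce_sum(per_example_loss) / global_batch_size
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
  ```
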
  ### Variable Constraint

  All Keras optimizers respect variable constraints. If a constraint function is
  passed to any variable, the constraint will be applied to the variable after
  the gradient has been applied to it.
  Important: Constraints are not supported for sparse (`IndexedSlices`) gradients.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The user
  needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train.  These are called
  <i>Slots</i>.  Slots have names and you can ask the optimizer for the names of
  the slots that it uses.  Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
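
  For example, a sketch (assuming a momentum-style optimizer, which keeps a
  "momentum" slot for each trained variable):

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  opt.minimize(lambda: 3 * var1 * var1, var_list=[var1])
  print(opt.get_slot_names())          # e.g. ['momentum']
  momentum_var = opt.get_slot(var1, 'momentum')
  ```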

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a sketch follows the list):

    - _resource_apply_dense (update variable given gradient tensor is dense)
    - _resource_apply_sparse (update variable given gradient tensor is sparse)
    - _create_slots (if your optimizer algorithm requires additional variables)
    - get_config (serialization of the optimizer, including all hyper parameters)
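
  A minimal sketch of such a subclass (illustrative only; a real optimizer
  would typically also implement `_resource_apply_sparse` and create any
  slots it needs):

  ```python
  class MyGradientDescent(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, name="MyGradientDescent", **kwargs):
      super(MyGradientDescent, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", learning_rate)

    def _resource_apply_dense(self, grad, var):
      lr = self._get_hyper("learning_rate", var.dtype.base_dtype)
      # Plain gradient descent: var <- var - lr * grad.
      return var.assign_sub(lr * grad)

    def get_config(self):
      config = super(MyGradientDescent, self).get_config()
      config["learning_rate"] = self._serialize_hyperparameter("learning_rate")
      return config
  ```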
  """

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of the learning rate; `lr` is included for
        backward compatibility, and it is recommended to use `learning_rate`
        instead.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # checks that all keyword arguments are non-negative.
      if kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._name = name
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay
    if "clipnorm" in kwargs:
      self.clipnorm = kwargs.pop("clipnorm")
    if "clipvalue" in kwargs:
      self.clipvalue = kwargs.pop("clipvalue")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use callable when the variable list would otherwise be incomplete before
        `minimize` since the variables are created at the first time `loss` is
        called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.
    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use callable when the variable list would otherwise be incomplete before
        `minimize` and the variables are created at the first time when `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      if not callable(var_list):
        tape.watch(var_list)
      loss_value = loss()
    if callable(var_list):
      var_list = var_list()
    var_list = nest.flatten(var_list)
    grads = tape.gradient(loss_value, var_list, grad_loss)

    if hasattr(self, "clipnorm"):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, "clipvalue"):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default():
      grads = gradients.gradients(loss, params)
    for grad, param in zip(grads, params):
      if grad is None:
        raise ValueError("Variable {} has `None` for gradient. "
                         "Please make sure that all of your ops have a "
                         "gradient defined (i.e. are differentiable). "
                         "Common ops without gradient: "
                         "K.argmax, K.round, K.eval.".format(param))
    if hasattr(self, "clipnorm"):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, "clipvalue"):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    # Create the iterations counter, hyper variables and slots if necessary.
    with ops.init_scope():
      _ = self.iterations
      self._create_hypers()
      self._create_slots(var_list)

    self._prepare(var_list)

    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})

  def _distributed_apply(self, distribution, grads_and_vars, name):
    """`apply_gradients` using a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    update_ops = []
    with backend.name_scope(name or self._name):
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with backend.name_scope("update" + scope_name):
          update_ops.extend(
              distribution.extended.update(
                  var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
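    """Computes gradients of `loss` w.r.t. `params` and applies them."""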
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
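    """Returns the slot variable named `slot_name` created for `var`."""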
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    pass

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(value, ops.Tensor) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    config = {"name": self._name}
    if hasattr(self, "clipnorm"):
      config["clipnorm"] = self.clipnorm
    if hasattr(self, "clipvalue"):
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Arguments:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional Python
          objects used to create this optimizer, such as a function used for a
          hyperparameter.

    Returns:
        An optimizer instance.
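
    Example (a round trip through `get_config` and `from_config`, using a
    built-in optimizer):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    opt2 = tf.keras.optimizers.SGD.from_config(opt.get_config())
    ```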
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tensor(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
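    """Returns the current values of this optimizer's weights."""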
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
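    """Sets this optimizer's weights from a list matching `self.weights`."""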
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a  weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
731 732
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):
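    """Creates an optimizer variable (e.g. `iterations` or a hyperparameter)."""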

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
      return x.value()

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)


def _filter_grads(grads_and_vars):
  """Filter out iterable with grad equal to None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients does not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if getattr(var, "_distributed_container", None) is not None:
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class _RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimzers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: _RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])