Commit 10666c59 authored by Thomas O'Malley, committed by TensorFlower Gardener

Keras ideal fit and compile.

Kept all new abstractions private for now. In a few weeks, if we're
comfortable that these abstractions are working and stable, we should expose
many of them publicly.

Capabilities added by this CL:

(1) Easy to create a custom training step by overriding Model._train_step
(a minimal sketch follows this list).
(2) Easy to create custom tf.function / DistStrat logic by overriding
Model._make_train_function.
(3) Advanced users can override Model.compile and Model.fit.
(4) Full support for dicts, nested structures, etc. with subclassed Models.
(5) The "power user" path (tf.data inputs) only modifies data in
Model._train_step, where this behavior is easy to override and disable. This
applies even to Keras's assumption that data is passed in (x, y, sample_weight)
format.
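
As a rough, non-authoritative sketch of capability (1), a custom training step
override might look like the following. At this commit the hook is still the
private Model._train_step, whose exact signature is not shown in this diff; the
sketch instead assumes the public `train_step` / `compiled_loss` /
`compiled_metrics` names these abstractions were later exposed under, so treat
every name here as an assumption rather than this CL's actual internals.

import numpy as np
import tensorflow as tf


class CustomStepModel(tf.keras.Model):
  """Sketch of a custom training step (capability (1)); assumes public names."""

  def __init__(self):
    super(CustomStepModel, self).__init__()
    self._dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self._dense(inputs)

  def train_step(self, data):
    # Keras assumes `data` arrives as (x, y, sample_weight); a custom step
    # is free to unpack it differently (capability (5)).
    x, y = data
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    self.compiled_metrics.update_state(y, y_pred)
    # "loss" reported to Callbacks is now a stateful metric (behavior
    # change (1) below), so returning metric results is sufficient.
    return {m.name: m.result() for m in self.metrics}


model = CustomStepModel()
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
model.fit(np.ones((8, 3)), np.zeros((8, 1)), batch_size=4, epochs=1, verbose=0)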

Behavior changes:

(1) "loss" passed to Callbacks is now stateful (like all other metrics in
Callbacks). This greatly simplifies the training step logic and callback logic.
(2) ProgbarLogger always uses steps. If steps is not available, the
ProgbarLogger handles inferring the steps after the first epoch.
(3) validation_batch_size added in `fit`, rather than inferring from generator.
(4) Model.inputs, Model.outputs, Model.input_names, and Model.output_names are
no longer populated for subclassed Models. Instead, "pseudo" output names are
created for subclassed Models, which are only used for metrics names and
SavedModel's signature.
(5) Cast NumPy floats to backend.floatx(), otherwise leave
unchanged (this is likely not a change, we did something like this in our old
version but the logic was scattered in many places)
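
The snippet below is a minimal, hypothetical illustration of behavior changes
(3) and (5), assuming the public `fit` signature that ships with this change;
the model and data are made up for illustration only.

import numpy as np
import tensorflow as tf

# Hypothetical data: float64 NumPy inputs are cast to backend.floatx()
# (float32 by default), while the integer targets are left unchanged
# (behavior change (5)).
x = np.random.random((8, 3))
y = np.zeros((8, 1), dtype=np.int64)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='sgd', loss='mse')

# Behavior change (3): the validation batch size is an explicit `fit`
# argument instead of being inferred from a generator.
model.fit(x, y,
          batch_size=4,
          validation_split=0.25,
          validation_batch_size=2,
          epochs=1,
          verbose=0)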

PiperOrigin-RevId: 296090972
Change-Id: Ia5ac833fd39085bddb016833bd338083d0dc5fc2
Parent abaab5b3
......@@ -195,6 +195,7 @@ class DistributedDumpingCallbackTest(
self.assertAllClose(device_1_matmul_values[0], [[10.0]])
self.assertAllClose(device_1_bias_add_values[0], [[11.0]])
# TODO(b/148461691): Fix for new Keras internals.
@combinations.generate(
combinations.combine(
distribution=[
......@@ -206,7 +207,8 @@ class DistributedDumpingCallbackTest(
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode):
def DISABLED_testKerasModelFitOnOneOrTwoDevices(self, distribution,
tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
......
......@@ -33,8 +33,12 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
def _save_model(self, model, saved_dir):
model.save(saved_dir, save_format='tf')
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
def _load_and_run_model(self,
distribution,
saved_dir,
predict_dataset,
experimental_run_tf_function,
output_name='output_1'):
restored_keras_model = save.load_model(saved_dir)
restored_keras_model._experimental_run_tf_function = (
experimental_run_tf_function)
......
......@@ -45,7 +45,7 @@ class SimpleFunctionalModel(model_collection_base.ModelAndInput):
"""A simple functional model and its inputs."""
def get_model(self, **kwargs):
output_name = 'output_layer'
output_name = 'output_1'
x = keras.layers.Input(shape=(3,), dtype=dtypes.float32)
y = keras.layers.Dense(5, dtype=dtypes.float32, name=output_name)(x)
......@@ -74,7 +74,7 @@ class SimpleSequentialModel(model_collection_base.ModelAndInput):
"""A simple sequential model and its inputs."""
def get_model(self, **kwargs):
output_name = 'output_layer'
output_name = 'output_1'
model = keras.Sequential()
y = keras.layers.Dense(
......@@ -106,7 +106,7 @@ class _SimpleModel(keras.Model):
self._dense_layer = keras.layers.Dense(5, dtype=dtypes.float32)
def call(self, inputs):
return {'output_layer': self._dense_layer(inputs)}
return self._dense_layer(inputs)
class SimpleSubclassModel(model_collection_base.ModelAndInput):
......
......@@ -41,8 +41,12 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
def _save_model(self, model, saved_dir):
keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
def _load_and_run_model(self,
distribution,
saved_dir,
predict_dataset,
experimental_run_tf_function,
output_name='output_1'):
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
predict_dataset,
output_name)
......
......@@ -35,8 +35,12 @@ class SavedModelKerasModelTest(test_base.TestSavedModelBase):
def _save_model(self, model, saved_dir):
saved_model.save(model, saved_dir)
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
def _load_and_run_model(self,
distribution,
saved_dir,
predict_dataset,
experimental_run_tf_function,
output_name='output_1'):
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
predict_dataset,
output_name)
......@@ -100,8 +104,12 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase):
call = model.__call__.get_concrete_function(tensor_spec.TensorSpec(None))
saved_model.save(model, saved_dir, signatures=call)
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
def _load_and_run_model(self,
distribution,
saved_dir,
predict_dataset,
experimental_run_tf_function,
output_name='output_1'):
del output_name, experimental_run_tf_function
model = saved_model.load(saved_dir)
return self._predict_with_model(distribution, model, predict_dataset)
......
......@@ -150,8 +150,12 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
"""
raise NotImplementedError('must be implemented in descendants')
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
def _load_and_run_model(self,
distribution,
saved_dir,
predict_dataset,
experimental_run_tf_function,
output_name='output_1'):
"""Load the model and run 1 step of predict with it.
This method must be implemented by the subclasses.
......@@ -162,10 +166,10 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
saved_dir: the string representing the path where the model is saved.
predict_dataset: the data used to do the predict on the model for
cross_replica context.
output_name: the string representing the name of the output layer of the
model.
experimental_run_tf_function: Whether to use the single execution path
for models.
output_name: the string representing the name of the output layer of the
model.
"""
raise NotImplementedError('must be implemented in descendants')
......@@ -211,10 +215,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
distribution=distribution,
saved_dir=saved_dir,
predict_dataset=predict_dataset,
# Note that subclassed model's output names aren't defined until after
# the model is built (in these tests, this occurs when the model is
# trained).
output_name=getattr(model, 'output_names', [None])[0],
experimental_run_tf_function=experimental_run_tf_function)
tolerance = get_tolerance(None, distribution)
......@@ -248,7 +248,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
distribution=None,
saved_dir=saved_dir,
predict_dataset=predict_dataset,
output_name=getattr(model, 'output_names', [None])[0],
experimental_run_tf_function=experimental_run_tf_function)
tolerance = get_tolerance(distribution, None)
......@@ -285,7 +284,6 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
distribution=distribution_for_restoring,
saved_dir=saved_dir,
predict_dataset=predict_dataset,
output_name=getattr(model, 'output_names', [None])[0],
experimental_run_tf_function=experimental_run_tf_function)
tolerance = get_tolerance(distribution_for_saving,
......
......@@ -186,7 +186,7 @@ class ForwardAccumulator(object):
>>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
>>> dense = tf.keras.layers.Dense(1)
>>> dense.build([2])
>>> dense.build([None, 2])
>>> with tf.autodiff.ForwardAccumulator(
... primals=dense.kernel,
... tangents=tf.constant([[1.], [0.]])) as acc:
......@@ -210,7 +210,7 @@ class ForwardAccumulator(object):
>>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
>>> dense = tf.keras.layers.Dense(1)
>>> dense.build([2])
>>> dense.build([None, 2])
>>> loss_fn = lambda: tf.reduce_sum((dense(x) - tf.constant([1., -1.])) ** 2.)
>>> kernel_fprop = []
>>> with tf.autodiff.ForwardAccumulator(
......
......@@ -1067,7 +1067,7 @@ class HessianTests(test.TestCase, parameterized.TestCase):
("MapFn", False)])
def testHessianOfVariables(self, use_pfor):
model = core.Dense(1)
model.build([2])
model.build([None, 2])
def _loss(*unused_args):
input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]])
......
......@@ -2271,7 +2271,8 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
flatten_inputs = nest.flatten_up_to(
input_signature,
inputs[:len(input_signature)],
expand_composites=True)
expand_composites=True,
check_types=False) # lists are convert to tuples for `tf.data`.
except ValueError:
raise ValueError("Structure of Python function inputs does not match "
"input_signature:\n%s" %
......
......@@ -4347,6 +4347,10 @@ def in_train_phase(x, alt, training=None):
Either `x` or `alt` based on the `training` flag.
the `training` flag defaults to `K.learning_phase()`.
"""
from tensorflow.python.keras.engine import base_layer_utils # pylint: disable=g-import-not-at-top
if training is None:
training = base_layer_utils.call_context().training
if training is None:
training = learning_phase()
......
......@@ -49,6 +49,7 @@ from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
......@@ -187,26 +188,67 @@ def make_logs(model, logs, outputs, mode, prefix=''):
class CallbackList(object):
"""Container abstracting a list of callbacks.
"""Container abstracting a list of callbacks."""
Arguments:
def __init__(self,
callbacks=None,
add_history=False,
add_progbar=False,
model=None,
**params):
"""Creates a container for `Callbacks`.
Arguments:
callbacks: List of `Callback` instances.
queue_length: Queue length for keeping
running statistics over callback execution time.
"""
add_history: Whether a `History` callback should be added, if one does not
already exist in `callback`s.
add_progbar: Whether a `ProgbarLogger` callback should be added, if one
does not already exist in `callback`s.
model: The `Model` these `Callback`s are used with.`
**params: If provided, parameters will be passed to each `Callback` via
`Callback.set_params`.
"""
self.callbacks = nest.flatten(callbacks) if callbacks else []
self._add_default_callbacks(add_history, add_progbar)
def __init__(self, callbacks=None, queue_length=10):
callbacks = callbacks or []
self.callbacks = [c for c in callbacks]
self.queue_length = queue_length
self.params = {}
self.model = None
if model:
self.set_model(model)
if params:
self.set_params(params)
self._queue_length = 10
self._reset_batch_timing()
def _add_default_callbacks(self, add_history, add_progbar):
"""Adds `Callback`s that are always present."""
self._progbar = None
self._history = None
for cb in self.callbacks:
if isinstance(cb, ProgbarLogger):
self._progbar = cb
elif isinstance(cb, History):
self._history = cb
if self._progbar is None and add_progbar:
self._progbar = ProgbarLogger(count_mode='steps')
self.callbacks.append(self._progbar)
if self._history is None and add_history:
self._history = History()
self.callbacks.append(self._history)
def _reset_batch_timing(self):
self._delta_t_batch = 0.
self._delta_ts = collections.defaultdict(
lambda: collections.deque([], maxlen=self.queue_length))
lambda: collections.deque([], maxlen=self._queue_length))
def _process_logs(self, logs):
if logs:
return {
k: v.numpy() if hasattr(v, 'numpy') else v for k, v in logs.items()
}
return {}
def append(self, callback):
self.callbacks.append(callback)
......@@ -218,6 +260,8 @@ class CallbackList(object):
def set_model(self, model):
self.model = model
if self._history:
model.history = self._history
for callback in self.callbacks:
callback.set_model(model)
......@@ -266,9 +310,11 @@ class CallbackList(object):
self.on_predict_end()
def on_batch_begin(self, batch, logs=None):
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
def on_batch_end(self, batch, logs=None):
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
def on_epoch_begin(self, epoch, logs=None):
......@@ -281,7 +327,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
self._reset_batch_timing()
......@@ -297,7 +343,7 @@ class CallbackList(object):
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
logs = logs or {}
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)
......@@ -309,6 +355,7 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
......@@ -318,6 +365,7 @@ class CallbackList(object):
batch: integer, index of batch within the current epoch.
logs: dict. Metric results for this batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
......@@ -328,6 +376,7 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
def on_test_batch_end(self, batch, logs=None):
......@@ -347,6 +396,7 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
def on_predict_batch_end(self, batch, logs=None):
......@@ -356,6 +406,7 @@ class CallbackList(object):
batch: integer, index of batch within the current epoch.
logs: dict. Metric results for this batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
def on_train_begin(self, logs=None):
......@@ -365,6 +416,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_train_begin(logs)
......@@ -375,6 +427,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_train_end(logs)
......@@ -385,6 +438,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_test_begin(logs)
......@@ -395,6 +449,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_test_end(logs)
......@@ -405,6 +460,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_predict_begin(logs)
......@@ -415,6 +471,7 @@ class CallbackList(object):
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = self._process_logs(logs)
for callback in self.callbacks:
callback.on_predict_end(logs)
......@@ -721,6 +778,7 @@ class ProgbarLogger(Callback):
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss, etc).
If not provided, defaults to the `Model`'s metrics.
Raises:
ValueError: In case of invalid `count_mode`.
......@@ -734,59 +792,96 @@ class ProgbarLogger(Callback):
self.use_steps = True
else:
raise ValueError('Unknown `count_mode`: ' + str(count_mode))
self.stateful_metrics = set(stateful_metrics or [])
self.log_values = None
def on_train_begin(self, logs=None):
self.verbose = self.params['verbose']
self.epochs = self.params['epochs']
# Defaults to all Model's metrics except for loss.
self.stateful_metrics = set(stateful_metrics) if stateful_metrics else None
def on_epoch_begin(self, epoch, logs=None):
self.seen = 0
if self.use_steps:
self.target = self.params['steps']
self.progbar = None
self.target = None
self.verbose = 1
self.epochs = 1
self._called_in_fit = False
def set_params(self, params):
self.verbose = params['verbose']
self.epochs = params['epochs']
if self.use_steps and 'steps' in params:
self.target = params['steps']
elif not self.use_steps and 'samples' in params:
self.target = params['samples']
else:
self.target = self.params['samples']
self.target = None # Will be inferred at the end of the first epoch.
def on_train_begin(self, logs=None):
# When this logger is called inside `fit`, validation is silent.
self._called_in_fit = True
if self.verbose:
if self.epochs > 1:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.progbar = Progbar(
target=self.target,
verbose=self.verbose,
stateful_metrics=self.stateful_metrics,
unit_name='step' if self.use_steps else 'sample')
def on_test_begin(self, logs=None):
if not self._called_in_fit:
self._reset_progbar()
def on_batch_begin(self, batch, logs=None):
self.log_values = []
def on_predict_begin(self, logs=None):
self._reset_progbar()
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
# In case of distribution strategy we can potentially run multiple steps
# at the same time, we should account for that in the `seen` calculation.
num_steps = logs.get('num_steps', 1)
if self.use_steps:
self.seen += num_steps
else:
self.seen += batch_size * num_steps
def on_epoch_begin(self, epoch, logs=None):
self._reset_progbar()
if self.verbose and self.epochs > 1:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
def on_train_batch_end(self, batch, logs=None):
self._batch_update_progbar(logs)
def on_test_batch_end(self, batch, logs=None):
if not self._called_in_fit:
self._batch_update_progbar(logs)
# Skip progbar update for the last batch;
# will be handled by on_epoch_end.
if self.verbose and (self.target is None or self.seen < self.target):
self.progbar.update(self.seen, self.log_values)
def on_predict_batch_end(self, batch, logs=None):
self._batch_update_progbar(None) # Don't pass prediction results.
def on_epoch_end(self, epoch, logs=None):
self._finalize_progbar(logs)
def on_test_end(self, logs=None):
if not self._called_in_fit:
self._finalize_progbar(logs)
def on_predict_end(self, logs=None):
self._finalize_progbar(logs)
def _reset_progbar(self):
self.seen = 0
self.progbar = None
def _batch_update_progbar(self, logs=None):
"""Updates the progbar."""
if self.stateful_metrics is None:
if self.model:
self.stateful_metrics = (set(m.name for m in self.model.metrics))
else:
self.stateful_metrics = set()
if self.progbar is None:
self.progbar = Progbar(
target=self.target,
verbose=self.verbose,
stateful_metrics=self.stateful_metrics,
unit_name='step' if self.use_steps else 'sample')
logs = copy.copy(logs) if logs else {}
batch_size = logs.pop('size', 0)
num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps.
logs.pop('batch', None)
add_seen = num_steps if self.use_steps else num_steps * batch_size
self.seen += add_seen
self.progbar.update(self.seen, list(logs.items()), finalize=False)
def _finalize_progbar(self, logs):
if self.target is None:
self.target = self.seen
self.progbar.target = self.seen
logs = logs or {}
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if self.verbose:
self.progbar.update(self.seen, self.log_values)
self.progbar.update(self.seen, list(logs.items()), finalize=True)
@keras_export('keras.callbacks.History')
......@@ -826,7 +921,7 @@ class ModelCheckpoint(Callback):
- Definition of 'best'; which quantity to monitor and whether it should be
maximized or minimized.
- The frequency it should save at. Currently, the callback supports saving at
the end of every epoch, or after a fixed number of training samples.
the end of every epoch, or after a fixed number of training batches.
- Whether only weights are saved, or the whole model is saved.
Example:
......@@ -873,11 +968,10 @@ class ModelCheckpoint(Callback):
(`model.save(filepath)`).
save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
the model after each epoch. When using integer, the callback saves the
model at end of a batch at which this many samples have been seen since
last saving. Note that if the saving isn't aligned to epochs, the
monitored metric may potentially be less reliable (it could reflect as
little as 1 batch, since the metrics get reset every epoch). Defaults to
`'epoch'`
model at end of this many batches. Note that if the saving isn't aligned
to epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset every
epoch). Defaults to `'epoch'`
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
"""
......@@ -899,7 +993,7 @@ class ModelCheckpoint(Callback):
self.save_weights_only = save_weights_only
self.save_freq = save_freq
self.epochs_since_last_save = 0
self._samples_seen_since_last_saving = 0
self._batches_seen_since_last_saving = 0
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
# file from `filepath` at the start of `model.fit()`
......@@ -917,7 +1011,7 @@ class ModelCheckpoint(Callback):
if 'period' in kwargs:
self.period = kwargs['period']
logging.warning('`period` argument is deprecated. Please use `save_freq` '
'to specify the frequency in number of samples seen.')
'to specify the frequency in number of batches seen.')
else:
self.period = 1
......@@ -1000,15 +1094,15 @@ class ModelCheckpoint(Callback):
# Restore the training state so the model is ready for next (possible)
# multi worker training.
del self._training_state
del self.model._training_state
self.model._training_state = None
def on_batch_end(self, batch, logs=None):
logs = logs or {}
if isinstance(self.save_freq, int):
self._samples_seen_since_last_saving += logs.get('size', 1)
if self._samples_seen_since_last_saving >= self.save_freq:
self._batches_seen_since_last_saving += 1
if self._batches_seen_since_last_saving >= self.save_freq:
self._save_model(epoch=self._current_epoch, logs=logs)
self._samples_seen_since_last_saving = 0
self._batches_seen_since_last_saving = 0
def on_epoch_begin(self, epoch, logs=None):
self._current_epoch = epoch
......@@ -1228,16 +1322,10 @@ class EarlyStopping(Callback):
>>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
>>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, callbacks=[callback])
Train on 5 samples
Epoch 1/10
5/5 [==============================] - ... loss: 6533.1904
Epoch 2/10
5/5 [==============================] - ... loss: 110183360.0000
Epoch 3/10
5/5 [==============================] - ... loss: 1862575718400.0000
Epoch 4/10
5/5 [==============================] - ... loss: 31485597793124352.0000
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> len(history.history['loss']) # Only 4 epochs are run.
4
"""
def __init__(self,
......
......@@ -35,6 +35,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import random_seed
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
......@@ -146,9 +147,10 @@ class CallbackCountsTest(keras_parameterized.TestCase):
@parameterized.named_parameters(('with_numpy', _get_numpy()),
('with_sequence', _get_sequence()))
def test_callback_hooks_are_called_in_fit(self, data):
if not context.executing_eagerly():
self.skipTest('Behavior changed in v2.')
x, y = data
val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
model = self._get_model()
counter = Counter()
......@@ -156,8 +158,8 @@ class CallbackCountsTest(keras_parameterized.TestCase):
x,
y,
validation_data=(val_x, val_y),
batch_size=2 if not is_sequence else None,
steps_per_epoch=5 if is_sequence else None,
batch_size=2,
steps_per_epoch=5,
epochs=5,
callbacks=[counter])
......@@ -264,8 +266,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
def test_progbar_logging(self):
model = self._get_model(input_shape=(3,))
x = array_ops.ones((50, 3))
y = array_ops.zeros((50, 2))
x = array_ops.ones((200, 3))
y = array_ops.zeros((200, 2))
dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
expected_log = r'(.*- loss:.*- my_acc:.*)+'
......@@ -279,8 +281,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
model = self._get_model()
self.assertFalse(model.built)
x = array_ops.ones((50, 3))
y = array_ops.zeros((50, 2))
x = array_ops.ones((200, 3))
y = array_ops.zeros((200, 2))
dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
expected_log = r'(.*- loss:.*- my_acc:.*)+'
......@@ -304,15 +306,15 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
self.assertRegexpMatches(printed.contents(), expected_log)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_progbar_logging_validation_split(self):
model = self._get_model(input_shape=(3,))
x = np.ones((100, 3))
y = np.zeros((100, 2))
expected_log = (
r'(?s).*1/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:'
r'.*2/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*')
r'(?s).*1/2.*8/8.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:'
r'.*2/2.*8/8.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*')
with self.captureWritesToStream(sys.stdout) as printed:
model.fit(x, y, batch_size=10, epochs=2, validation_split=0.2)
......@@ -587,7 +589,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
monitor=monitor,
save_best_only=save_best_only,
mode=mode,
save_freq=30,
save_freq=15,
period=100) # The period should be ignored (this test tests this).
]
assert not os.path.exists(filepath.format(epoch=3))
......@@ -638,8 +640,8 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
def get_input_datasets():
# Simple training input.
train_input = [[1]] * 16
train_label = [[0]] * 16
train_input = [[1.]] * 16
train_label = [[0.]] * 16
ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
return ds.batch(8, drop_remainder=True)
......@@ -1268,40 +1270,40 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
values.append(x)
assert 'nan' in values[-1], 'The last epoch was not logged.'
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_TerminateOnNaN(self):
with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
input_shape=(INPUT_DIM,),
num_classes=NUM_CLASSES)
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
input_shape=(INPUT_DIM,),
num_classes=NUM_CLASSES)
y_test = np_utils.to_categorical(y_test)
y_train = np_utils.to_categorical(y_train)
cbks = [keras.callbacks.TerminateOnNaN()]
model = keras.models.Sequential()
initializer = keras.initializers.Constant(value=1e5)
for _ in range(5):
model.add(
keras.layers.Dense(
2,
input_dim=INPUT_DIM,
activation='relu',
kernel_initializer=initializer))
model.add(keras.layers.Dense(NUM_CLASSES))
model.compile(loss='mean_squared_error', optimizer='rmsprop')
history = model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=20)
loss = history.history['loss']
self.assertEqual(len(loss), 1)
self.assertEqual(loss[0], np.inf)
y_test = np_utils.to_categorical(y_test)
y_train = np_utils.to_categorical(y_train)
cbks = [keras.callbacks.TerminateOnNaN()]
model = keras.models.Sequential()
initializer = keras.initializers.Constant(value=1e5)
for _ in range(5):
model.add(
keras.layers.Dense(
2,
input_dim=INPUT_DIM,
activation='relu',
kernel_initializer=initializer))
model.add(keras.layers.Dense(NUM_CLASSES))
model.compile(loss='mean_squared_error', optimizer='rmsprop')
history = model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=20)
loss = history.history['loss']
self.assertEqual(len(loss), 1)
self.assertTrue(np.isnan(loss[0]))
@unittest.skipIf(
os.name == 'nt',
......@@ -1406,14 +1408,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
callbacks=cbks,
epochs=1)
def test_callback_params_samples(self):
x, y = np.ones((64, 3)), np.ones((64, 2))
model = testing_utils.get_small_sequential_mlp(
num_hidden=10, num_classes=2, input_dim=3)
def test_progbar_infers_steps(self):
x, y = np.ones((10, 1)), np.ones((10, 1))
data = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2)
data = data.filter(lambda x, y: True) # Unknown cardinality.
progbar = keras.callbacks.ProgbarLogger('steps')
model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
callback = keras.callbacks.Callback()
model.evaluate(x, y, callbacks=[callback])
self.assertEqual(callback.params['samples'], 64)
self.assertIsNone(progbar.target)
model.fit(data, epochs=2, callbacks=[progbar])
self.assertEqual(progbar.target, 5)
# A summary that was emitted during a test. Fields:
......
......@@ -950,10 +950,16 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
optimizer='adam',
experimental_run_tf_function=experimental_run_tf_function)
def map_fn(img, lbl, weight):
inputs = {'img': img, 'lbl': lbl, 'weight': weight}
targets = {}
return inputs, targets
if context.executing_eagerly():
def map_fn(img, lbl, weight):
inputs = {'img': img, 'lbl': lbl, 'weight': weight}
return (inputs,)
else:
def map_fn(img, lbl, weight):
inputs = {'img': img, 'lbl': lbl, 'weight': weight}
return inputs, {}
fake_imgs = np.ones([50, 64, 64, 3], dtype=np.float32)
fake_lbls = np.ones([50, 64, 64, 1], dtype=np.float32)
......@@ -1178,7 +1184,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
@combinations.generate(
......@@ -1776,7 +1782,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
experimental_run_tf_function=experimental_run_tf_function)
ds_history = ds_model.fit(
x, y, validation_data=(x, y), validation_steps=2, epochs=2)
self.assertLen(ds_model.metrics, 1)
# includes stateful loss metric in eager.
metrics_len = 2 if context.executing_eagerly() else 1
self.assertLen(ds_model.metrics, metrics_len)
self.assertAllClose(history.history, ds_history.history)
......@@ -1830,7 +1838,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
experimental_run_tf_function=experimental_run_tf_function)
ds_history = ds_model.fit(
x, y, validation_data=(x, y), validation_steps=2, epochs=2)
self.assertLen(ds_model.metrics, 1)
# includes stateful loss metric in eager.
metrics_len = 2 if context.executing_eagerly() else 1
self.assertLen(ds_model.metrics, metrics_len)
self.assertAllClose(history.history, ds_history.history)
......@@ -1870,7 +1880,9 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
experimental_run_tf_function=experimental_run_tf_function)
ds_history = ds_model.fit(
x, y, validation_data=(x, y), validation_steps=2, epochs=2)
self.assertLen(ds_model.metrics, 1)
# includes stateful loss metric in eager.
metrics_len = 2 if context.executing_eagerly() else 1
self.assertLen(ds_model.metrics, metrics_len)
self.assertAllClose(history.history, ds_history.history)
......
......@@ -257,11 +257,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
experimental_run_tf_function=experimental_run_tf_function)
dataset = keras_test_lib.get_dataset(distribution)
exception_error_message = (
'`validation_split` argument is not supported when ')
# Test with validation split
with self.assertRaisesRegexp(ValueError, exception_error_message):
with self.assertRaises(ValueError):
model.fit(
dataset,
epochs=1,
......@@ -272,9 +269,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
ValueError, '`sample_weight` argument is not supported when.*'
'dataset'):
with self.assertRaises(ValueError):
model.fit(
dataset,
epochs=1,
......@@ -285,69 +280,14 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
# Test with not specifying the `steps` argument for dataset with infinite
# cardinality.
dataset = dataset.repeat()
with self.assertRaisesRegexp(
ValueError, 'When passing an infinitely '
'repeating dataset, you must specify the '
'`steps_per_epoch` argument'):
with self.assertRaises(ValueError):
model.fit(dataset, epochs=1, verbose=0)
with self.assertRaisesRegexp(
ValueError, 'When passing an infinitely '
'repeating dataset, you must specify the '
'`steps` argument'):
with self.assertRaises(ValueError):
model.evaluate(dataset, verbose=0)
with self.assertRaisesRegexp(
ValueError, 'When passing an infinitely '
'repeating dataset, you must specify the '
'`steps` argument'):
with self.assertRaises(ValueError):
model.predict(dataset, verbose=0)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=['graph', 'eager'],
experimental_run_tf_function=[True, False]))
def test_calling_with_unsupported_predefined_callbacks(
self, distribution, experimental_run_tf_function):
with self.cached_session():
with distribution.scope():
model = keras_test_lib.get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
model.compile(
optimizer,
loss,
metrics=metrics,
experimental_run_tf_function=experimental_run_tf_function)
dataset = keras_test_lib.get_dataset(distribution)
def schedule(_):
return 0.001
with self.assertRaisesRegexp(
ValueError, 'You must specify a Keras Optimizer V2 when '
'using'):
model.fit(
dataset,
epochs=1,
steps_per_epoch=2,
verbose=0,
callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
with self.assertRaisesRegexp(
ValueError, 'You must specify a Keras Optimizer V2 when '
'using'):
model.fit(
dataset,
epochs=1,
steps_per_epoch=2,
verbose=0,
callbacks=[keras.callbacks.ReduceLROnPlateau()])
@combinations.generate(
combinations.combine(
distribution=[
......
......@@ -29,8 +29,6 @@ py_library(
"training_generator.py",
"training_utils.py",
"training_v1.py",
"training_v2.py",
"training_v2_utils.py",
],
srcs_version = "PY2AND3",
deps = [
......@@ -428,24 +426,6 @@ tf_py_test(
],
)
tf_py_test(
name = "training_v2_utils_test",
size = "medium",
srcs = ["training_v2_utils_test.py"],
python_version = "PY3",
tags = [
"no_oss", # TODO(b/135021748) reenable
"notsan",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "network_test",
size = "medium",
......
......@@ -22,6 +22,7 @@ import collections
import functools
import itertools
import threading
import weakref
import numpy as np
import six
......@@ -230,6 +231,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# A list of metric instances corresponding to the symbolic metric tensors
# added using the `add_metric` API.
self._metrics = []
# Ensures the same metric is not added multiple times in `MirroredStrategy`.
self._metrics_lock = threading.Lock()
# Both graph and subclassed networks have a dtype policy. For graph
# networks, the policy's compute and variable dtypes are ignored, but other
......@@ -849,10 +852,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if hasattr(self, '_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by
# a call to self._set_inputs().
# TODO(b/120997007): This should be done in Eager as well, but
# causes garbage collection issues because of the placeholders
# created on the default Keras graph.
self._set_inputs(inputs, outputs)
self._set_inputs(cast_inputs, outputs)
else:
# Eager execution on data tensors.
with backend.name_scope(self._name_scope()):
......@@ -863,6 +863,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
outputs = self.call(cast_inputs, *args, **kwargs)
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, input_masks)
if hasattr(self, '_set_save_spec'):
self._set_save_spec(cast_inputs)
return outputs
......@@ -1146,7 +1148,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
collected_metrics.extend(layer._metrics)
with layer._metrics_lock:
collected_metrics.extend(layer._metrics)
return collected_metrics
@doc_controls.for_subclass_implementers
......@@ -1938,20 +1941,29 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# on it, otherwise we create a new metric instance and
# add it to the `metrics` list.
metric_obj = getattr(value, '_metric_obj', None)
if metric_obj:
name = metric_obj.name
# Tensors that come from a Metric object already updated the Metric state.
should_update_state = not metric_obj
name = metric_obj.name if metric_obj else name
match = self._get_existing_metric(name)
if match:
# Tensors that come from a Metric object already updated the Metric state.
if not metric_obj:
match(value)
return
with self._metrics_lock:
match = self._get_existing_metric(name)
if match:
metric_obj = match
elif metric_obj:
self._metrics.append(metric_obj)
else:
from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top
if aggregation is None:
raise ValueError(
'`aggregation` must be specified when passing a `Tensor` '
'to `add_metric`.')
assert aggregation is not None
metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype)
self._metrics.append(metric_obj)
if not metric_obj:
assert aggregation is not None
metric_obj, _ = base_layer_utils.create_mean_metric(value, name)
self._metrics.append(metric_obj)
if should_update_state:
metric_obj(value)
return
def _symbolic_add_metric(self, value, aggregation=None, name=None):
base_layer_utils.check_graph_consistency(value, method='add_metric')
......@@ -2259,7 +2271,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
layers = trackable_layer_utils.filter_empty_layer_containers(self._layers)
# Keep track of each top-level layers' `trainable` as well as the
# state of all of its sublayers.
trainable_state = {self: self.trainable}
trainable_state = weakref.WeakKeyDictionary()
trainable_state[self] = self.trainable
for layer in layers:
trainable_state.update(layer._get_trainable_state())
return trainable_state
......@@ -2565,10 +2578,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# so shouldn't be copied.
state = self.__dict__.copy()
state.pop('_thread_local', None)
state.pop('_metrics_lock', None)
return state
def __setstate__(self, state):
state['_thread_local'] = threading.local()
state['_metrics_lock'] = threading.Lock()
# Bypass Trackable logic as `__dict__` already contains this info.
object.__setattr__(self, '__dict__', state)
......
......@@ -187,7 +187,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
self.assertEqual(model.outputs, [None])
self.assertEqual(model.outputs, None)
def test_dynamic_subclassed_model_with_shape_inference(self):
......@@ -210,8 +210,10 @@ class BaseLayerTest(keras_parameterized.TestCase):
model = MyModel()
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
self.assertEqual(model.outputs[0].shape.as_list(), [None, 3])
x, y = np.random.random((2, 3)), np.random.random((2, 3))
model.train_on_batch(x, y)
outputs = model(x)
self.assertEqual(outputs.shape.as_list(), [2, 3])
def test_deepcopy(self):
with context.eager_mode():
......@@ -331,42 +333,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
keras.backend.set_learning_phase(0)
self.assertEqual(get_learning_phase_value(), 0)
@keras_parameterized.run_all_keras_modes
def test_learning_phase_freezing_for_layers_in_predict(self):
if not (testing_utils.should_run_eagerly() or
testing_utils.should_run_tf_function()):
self.skipTest('Predict fails to override the outer learning phase in'
'the FuncGraph path.')
class LearningPhaseLayer(keras.layers.Layer):
def call(self, inputs):
return keras.backend.in_train_phase(
lambda: array_ops.ones_like(inputs),
lambda: array_ops.zeros_like(inputs))
def get_learning_phase_value():
model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
model._run_eagerly = testing_utils.should_run_eagerly()
model._experimental_run_tf_function = (
testing_utils.should_run_tf_function())
return np.sum(model.predict(np.ones((1, 1))))
self.assertEqual(get_learning_phase_value(), 0)
# Test scope.
with keras.backend.learning_phase_scope(1):
self.assertEqual(get_learning_phase_value(), 0)
# The effects of the scope end after exiting it.
self.assertEqual(get_learning_phase_value(), 0)
# Test setting.
keras.backend.set_learning_phase(1)
self.assertEqual(get_learning_phase_value(), 0)
keras.backend.set_learning_phase(0)
self.assertEqual(get_learning_phase_value(), 0)
# Cannot be enabled with `run_eagerly=True`, see b/123904578
@test_util.run_all_in_graph_and_eager_modes
def test_layer_can_return_variable(self):
......
......@@ -21,9 +21,9 @@ import copy
import six
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.keras import losses as losses_mod
from tensorflow.python.keras import metrics as metrics_mod
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
......@@ -35,6 +35,10 @@ class LossesContainer(object):
"""A container class for losses passed to `Model.compile`."""
def __init__(self, losses, loss_weights=None, output_names=None):
# Keep user-supplied values untouched for recompiling and serialization.
self._user_losses = losses
self._user_loss_weights = loss_weights
self._losses = losses
self._loss_weights = loss_weights
self._output_names = output_names
......@@ -59,7 +63,7 @@ class LossesContainer(object):
if self._output_names is None:
# In Subclass API, output names like 'output_1' are used for
# `Metric` names.
self._output_names = create_output_names(y_pred)
self._output_names = create_pseudo_output_names(y_pred)
# Accept a dict of losses keyed by output_name when outputs are a flat
# list.
......@@ -94,7 +98,11 @@ class LossesContainer(object):
self._built = True
def __call__(self, y_true, y_pred, sample_weight=None):
def __call__(self,
y_true,
y_pred,
sample_weight=None,
regularization_losses=None):
"""Computes the overall loss.
Arguments:
......@@ -104,14 +112,19 @@ class LossesContainer(object):
per-sample loss weights. If one Tensor is passed, it is used for all
losses. If multiple Tensors are passed, the structure should match
`y_pred`.
regularization_losses: Additional losses to be added to the total loss.
Returns:
Tuple of `(total_loss, per_output_loss_list)`
"""
y_true = map_to_output_names(y_pred, self._output_names, y_true)
sample_weight = map_to_output_names(y_pred, self._output_names,
sample_weight)
if not self._built:
self._build(y_pred)
y_true = nest.flatten(y_true)
y_true = nest.flatten(y_true) if y_true is not None else []
y_pred = nest.flatten(y_pred)
# TODO(omalleyt): Remove ambiguity here.
......@@ -127,45 +140,47 @@ class LossesContainer(object):
if len(sample_weight) == 1 and len(y_pred) > 1:
sample_weight = sample_weight * len(y_pred)
loss_values = []
loss_values = [] # Used for gradient calculation.
loss_metric_values = [] # Used for loss metric calculation.
zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
self._per_output_metrics)
for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
if loss_obj is None: # Ok to have no loss for an output.
continue
y_t = math_ops.cast(y_t, y_p.dtype)
if sw is not None:
sw = math_ops.cast(sw, y_p.dtype)
# Handle Keras mask on outputs.
mask = getattr(y_p, '_keras_mask', None)
if mask is not None:
mask = math_ops.cast(mask, y_p.dtype)
if sw is not None:
mask, _, sw = (
tf_losses_utils.squeeze_or_expand_dimensions(
mask, sample_weight=sw))
sw *= mask
else:
sw = mask
y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
sw = apply_mask(y_p, sw)
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
loss_metric_value = loss_value
# Correct for the `Mean` loss metrics counting each replica as a batch.
if loss_obj.reduction == losses_utils.ReductionV2.SUM:
loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
if metric_obj is not None:
metric_obj.update_state(loss_value)
metric_obj.update_state(loss_metric_value)
if loss_weight is not None:
loss_value *= loss_weight
loss_metric_value *= loss_weight
if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
loss_obj.reduction == losses_utils.ReductionV2.AUTO):
loss_value = losses_utils.scale_loss_for_distribution(loss_value)
loss_values.append(loss_value)
loss_metric_values.append(loss_metric_value)
if regularization_losses:
reg_loss = math_ops.add_n(regularization_losses)
loss_metric_values.append(reg_loss)
loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))
if loss_values:
total_loss_metric_value = math_ops.add_n(loss_metric_values)
self._loss_metric.update_state(total_loss_metric_value)
total_loss = math_ops.add_n(loss_values)
self._loss_metric.update_state(total_loss)
return total_loss
else:
# Ok for a model to have no compiled loss.
......@@ -188,7 +203,8 @@ class LossesContainer(object):
loss = losses_mod.get(loss)
if not isinstance(loss, losses_mod.Loss):
loss = losses_mod.LossFunctionWrapper(loss, name=loss.__name__)
loss_name = loss.__name__
loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
loss._allow_sum_over_batch_size = True # pylint: disable=protected-access
return loss
......@@ -197,6 +213,10 @@ class MetricsContainer(object):
"""A container class for metrics passed to `Model.compile`."""
def __init__(self, metrics=None, weighted_metrics=None, output_names=None):
# Keep user-supplied values untouched for recompiling and serialization.
self._user_metrics = metrics
self._user_weighted_metrics = weighted_metrics
self._metrics = metrics
self._weighted_metrics = weighted_metrics
self._output_names = output_names
......@@ -207,22 +227,19 @@ class MetricsContainer(object):
"""Metrics created by this container."""
if not self._built:
return []
metrics = [
metric_obj for metric_obj in nest.flatten(self._metrics)
if metric_obj is not None
]
weighted_metrics = [
metric_obj for metric_obj in nest.flatten(self._weighted_metrics)
if metric_obj is not None
]
return metrics + weighted_metrics
return self._metrics_in_order
def _build(self, y_pred, y_true):
"""One-time setup of metric objects."""
if self._output_names is None:
# Subclass output names like 'output_1' are used for `Metric` names.
self._output_names = create_output_names(y_pred)
self._output_names = create_pseudo_output_names(y_pred)
# If a single metric or flat list of metrics, apply to all outputs.
self._metrics = self._maybe_broadcast(self._metrics, y_pred)
self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
y_pred)
# Accept a dict of metrics keyed by output_name when outputs are a flat
# list.
......@@ -231,10 +248,13 @@ class MetricsContainer(object):
self._weighted_metrics = map_to_output_names(y_pred, self._output_names,
self._weighted_metrics)
# If a single metric is supplied, apply to all outputs.
self._metrics = self._maybe_broadcast(self._metrics, y_pred)
self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
y_pred)
# Standardize on tuple since `tf.data` turns lists into `Tensor`s.
# pylint: disable=protected-access
y_pred = nest._list_to_tuple(y_pred)
y_true = nest._list_to_tuple(y_true)
self._metrics = nest._list_to_tuple(self._metrics)
self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
# pylint: enable=protected-access
# Convert to `Metric` objects, potentially disambiguating based on output
# properties.
......@@ -252,6 +272,17 @@ class MetricsContainer(object):
# Assumes metrics, weighted_metrics have been flattened up to outputs.
self._set_metric_names()
# Cache the flat order needed when returning metrics, for backwards compat.
self._metrics_in_order = []
for output_metrics, output_weighted_metrics in zip(self._metrics,
self._weighted_metrics):
for m in nest.flatten(output_metrics):
if m is not None:
self._metrics_in_order.append(m)
for wm in nest.flatten(output_weighted_metrics):
if wm is not None:
self._metrics_in_order.append(wm)
self._built = True
def _set_metric_names(self):
......@@ -277,9 +308,13 @@ class MetricsContainer(object):
if wm is None:
continue
if is_multi_output:
wm._name = output_name + '_' + wm._name
if wm._name in metric_names:
if output_name + '_' + wm._name in metric_names:
wm._name = output_name + '_weighted_' + wm._name
else:
wm._name = output_name + '_' + wm._name
elif wm._name in metric_names:
wm._name = 'weighted_' + wm._name
if wm._name in metric_names:
raise ValueError('Found two metrics with the same name: {}'.format(
wm._name))
......@@ -288,9 +323,16 @@ class MetricsContainer(object):
def update_state(self, y_true, y_pred, sample_weight=None):
"""Updates the state of per-output metrics."""
flat_y_true = nest.flatten(y_true)
y_true = map_to_output_names(y_pred, self._output_names, y_true)
sample_weight = map_to_output_names(y_pred, self._output_names,
sample_weight)
flat_y_true = nest.flatten(y_true) if y_true is not None else []
flat_y_pred = nest.flatten(y_pred)
if not flat_y_true:
return # Handle case where no targets are passed.
# TODO(omalleyt): Remove ambiguity here (see LossesContainer).
if len(flat_y_true) == 1 and len(flat_y_pred) > 1:
y_true = nest.map_structure(lambda _: flat_y_true[0], y_pred)
......@@ -311,21 +353,8 @@ class MetricsContainer(object):
zip_args = (y_true, y_pred, sample_weight, self._metrics,
self._weighted_metrics)
for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
y_t = math_ops.cast(y_t, y_p.dtype)
if sw is not None:
sw = math_ops.cast(sw, y_p.dtype)
# Handle Keras mask on outputs.
mask = getattr(y_p, '_keras_mask', None)
if mask is not None:
mask = math_ops.cast(mask, y_p.dtype)
if sw is not None:
mask, _, sw = (
tf_losses_utils.squeeze_or_expand_dimensions(
mask, sample_weight=sw))
sw *= mask
else:
sw = mask
y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
sw = apply_mask(y_p, sw)
for metric_obj in metric_objs:
if metric_obj is None:
......@@ -339,7 +368,7 @@ class MetricsContainer(object):
def _get_metric_objects(self, metrics, y_t, y_p):
"""Convert user-supplied metrics to `Metric` objects."""
metrics = generic_utils.to_list(metrics)
metrics = nest.flatten(metrics)
return [self._get_metric_object(m, y_t, y_p) for m in metrics]
def _get_metric_object(self, metric, y_t, y_p):
......@@ -399,31 +428,47 @@ class MetricsContainer(object):
return metric_obj
def _maybe_broadcast(self, metrics, y_pred):
"""If a single Metric is supplied, applies it to all outputs."""
"""If a flat list of Metrics is supplied, apply them to all outputs."""
def _should_broadcast(metrics):
single_valued_list = (
isinstance(metrics, list) and len(metrics) == 1 and
not nest.is_sequence(metrics[0]))
# I.e. `metrics=['accuracy']` or `metrics='accuracy'`.
# In this special case we apply the metric to each output.
return not nest.is_sequence(metrics) or single_valued_list
def _copy(metric):
if isinstance(metric, metrics_mod.Metric):
return metrics_mod.Metric.from_config(metric.get_config())
return metric
# e.g. 'mse'.
if not nest.is_sequence(metrics):
return True
# e.g. ['mse'] or ['mse', 'mae'].
return (isinstance(metrics, (list, tuple)) and
not any(nest.is_sequence(m) for m in metrics))
if _should_broadcast(metrics):
metric = metrics[0] if isinstance(metrics, list) else metrics
return nest.map_structure(lambda _: _copy(metric), y_pred)
copy_metrics = len(nest.flatten(y_pred)) > 1
def _maybe_copy(m):
if copy_metrics and isinstance(m, metrics_mod.Metric):
return m.__class__.from_config(m.get_config())
return m
metrics = nest.flatten(metrics)
return nest.map_structure(lambda _: [_maybe_copy(m) for m in metrics],
y_pred)
return metrics
def create_output_names(y_pred):
"""Creates output names for subclassed Model outputs.
def create_pseudo_output_names(outputs):
"""Create pseudo output names for a subclassed Model."""
return _create_pseudo_names(outputs, prefix='output_')
def create_pseudo_input_names(inputs):
"""Create pseudo input names for a subclassed Model."""
return _create_pseudo_names(inputs, prefix='input_')
These names are used for naming `Metric`s.
def _create_pseudo_names(tensors, prefix):
"""Creates pseudo {input | output} names for subclassed Models.
Warning: this function should only be used to define default
names for `Metics` and `SavedModel`. No other use cases should
rely on a `Model`'s input or output names.
Example with dict:
......@@ -436,10 +481,11 @@ def create_output_names(y_pred):
`['output_1', 'output_2']`
Arguments:
y_pred: `Model`'s outputs.
tensors: `Model`'s outputs or inputs.
prefix: 'output_' for outputs, 'input_' for inputs.
Returns:
Flattened list of output names.
Flattened list of pseudo names.
"""
def one_index(ele):
......@@ -448,18 +494,18 @@ def create_output_names(y_pred):
return ele + 1
return ele
flat_paths = list(nest.yield_flat_paths(y_pred))
flat_paths = list(nest.yield_flat_paths(tensors))
flat_paths = nest.map_structure(one_index, flat_paths)
output_names = []
names = []
for path in flat_paths:
if not path:
output_name = 'output_1'
name = prefix + '1' # Single output.
else:
output_name = '_'.join(str(p) for p in path)
name = '_'.join(str(p) for p in path)
if isinstance(path[0], int):
output_name = 'output_' + output_name
output_names.append(output_name)
return output_names
name = prefix + name
names.append(name)
return names
def map_to_output_names(y_pred, output_names, struct):
......@@ -473,7 +519,7 @@ def map_to_output_names(y_pred, output_names, struct):
For the Functional API, the output names are the names of the
last layer of each output. For the Subclass API, the output names
are determined by `create_output_names` (For example:
are determined by `create_pseudo_output_names` (For example:
`['output_1', 'output_2']` for a list of outputs).
This mapping preserves backwards compatibility for `compile` and
......@@ -492,17 +538,52 @@ def map_to_output_names(y_pred, output_names, struct):
outputs_are_flat_list = (
isinstance(y_pred, (list, tuple)) and
not any(nest.is_sequence(y_p) for y_p in y_pred))
if not outputs_are_flat_list:
# In this case, `y_pred` and `struct` must have the same structure.
single_output = not nest.is_sequence(y_pred)
if (single_output or outputs_are_flat_list) and isinstance(struct, dict):
output_names = output_names or create_pseudo_output_names(y_pred)
struct = copy.copy(struct)
new_struct = [struct.pop(name, None) for name in output_names]
if struct:
raise ValueError('Found unexpected keys that do not correspond '
'to any Model output: {}. Expected: {}'.format(
struct.keys(), output_names))
if len(new_struct) == 1:
return new_struct[0]
return new_struct
else:
return struct
if not isinstance(struct, dict):
return struct
struct = copy.copy(struct)
new_struct = [struct.pop(name, None) for name in output_names]
if struct:
raise ValueError('Found unexpected keys that do not correspond '
'to any Model output: {}. Expected: {}'.format(
struct.keys(), output_names))
return new_struct
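A usage sketch for `map_to_output_names`, assuming it is importable from `compile_utils` as in the imports added elsewhere in this change; the output names and loss values are illustrative:

import tensorflow as tf

from tensorflow.python.keras.engine import compile_utils

y_pred = [tf.zeros((4, 1)), tf.zeros((4, 1))]  # two flat Model outputs
losses = {'out_b': 'mse', 'out_a': 'mae'}      # dict keyed by output name
ordered = compile_utils.map_to_output_names(y_pred, ['out_a', 'out_b'], losses)
# `ordered` is expected to be ['mae', 'mse'], re-ordered to match the outputs;
# a key that matches no output name would raise the ValueError shown above.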
def match_dtype_and_rank(y_t, y_p, sw):
"""Match dtype and rank of predictions."""
# Rank.
y_t_rank = len(y_t.shape)
y_p_rank = len(y_p.shape)
if y_t_rank == 1 and y_p_rank == 2:
y_t = array_ops.expand_dims_v2(y_t, axis=-1)
if sw is not None:
sw_rank = len(sw.shape)
if sw_rank == 1 and y_p_rank == 2:
sw = array_ops.expand_dims_v2(sw, axis=-1)
# Dtype.
y_t = math_ops.cast(y_t, y_p.dtype)
if sw is not None:
sw = math_ops.cast(sw, y_p.dtype)
return y_t, y_p, sw
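A short sketch of the coercion above; the import path is an assumption based on where these helpers appear in this change:

import tensorflow as tf

from tensorflow.python.keras.engine import compile_utils

y_t = tf.constant([0, 1, 1])               # rank-1 integer targets
y_p = tf.zeros((3, 1), dtype=tf.float32)   # rank-2 float32 predictions
sw = tf.constant([1., 1., 0.])             # rank-1 sample weights
y_t, y_p, sw = compile_utils.match_dtype_and_rank(y_t, y_p, sw)
# y_t and sw now have shape (3, 1) and dtype float32, matching y_p.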
def apply_mask(y_p, sw):
"""Applies any mask on predictions to sample weights."""
# Handle Keras mask on outputs.
mask = getattr(y_p, '_keras_mask', None)
if mask is not None:
mask = math_ops.cast(mask, y_p.dtype)
if sw is not None:
mask, _, sw = (
tf_losses_utils.squeeze_or_expand_dimensions(mask, sample_weight=sw))
sw *= mask
else:
sw = mask
return sw
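A small sketch of the mask handling: without a `_keras_mask` attribute the sample weight passes through untouched, while a mask (normally attached by masking-aware layers) would be cast and multiplied in. Import path assumed as above:

import tensorflow as tf

from tensorflow.python.keras.engine import compile_utils

y_p = tf.ones((2, 3))            # no _keras_mask attached
sw = tf.constant([[1.], [0.5]])
assert compile_utils.apply_mask(y_p, sw) is sw  # returned unchanged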
......@@ -234,29 +234,37 @@ class MetricsContainerTest(keras_parameterized.TestCase):
def test_list_of_metrics_list_of_outputs(self):
metric_container = compile_utils.MetricsContainer(
metrics=['mse', 'mae'],
metrics=['mse', 'mae'], # Should broadcast to both outputs.
weighted_metrics=['accuracy']) # Should broadcast to both outputs.
y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))]
sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
metric_container.update_state(y_t, y_p, sample_weight=sw)
self.assertLen(metric_container.metrics, 4)
self.assertLen(metric_container.metrics, 6)
mse_metric = metric_container.metrics[0]
self.assertEqual(mse_metric.name, 'output_1_mse')
self.assertEqual(mse_metric.result().numpy(), 0.)
mae_metric = metric_container.metrics[1]
self.assertEqual(mae_metric.name, 'output_2_mae')
self.assertEqual(mae_metric.result().numpy(), 2.)
mse_metric = metric_container.metrics[1]
self.assertEqual(mse_metric.name, 'output_1_mae')
self.assertEqual(mse_metric.result().numpy(), 0.)
acc_metric_1 = metric_container.metrics[2]
self.assertEqual(acc_metric_1.name, 'output_1_accuracy')
self.assertEqual(acc_metric_1.result().numpy(), 1.)
self.assertEqual(acc_metric_1._fn, metrics_mod.binary_accuracy)
acc_metric_2 = metric_container.metrics[3]
mae_metric = metric_container.metrics[3]
self.assertEqual(mae_metric.name, 'output_2_mse')
self.assertEqual(mae_metric.result().numpy(), 4.)
mae_metric = metric_container.metrics[4]
self.assertEqual(mae_metric.name, 'output_2_mae')
self.assertEqual(mae_metric.result().numpy(), 2.)
acc_metric_2 = metric_container.metrics[5]
self.assertEqual(acc_metric_2.name, 'output_2_accuracy')
self.assertEqual(acc_metric_2.result().numpy(), 0.)
self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)
......@@ -281,16 +289,16 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(mse_metric.name, 'out1_mse')
self.assertEqual(mse_metric.result().numpy(), 0.)
mae_metric = metric_container.metrics[1]
weighted_mse_metric = metric_container.metrics[1]
self.assertEqual(weighted_mse_metric.name, 'out1_weighted_mse')
self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
mae_metric = metric_container.metrics[2]
self.assertEqual(mae_metric.name, 'out2_mae')
self.assertEqual(mae_metric.result().numpy(), 2.)
weighted_mse_metric = metric_container.metrics[2]
self.assertEqual(weighted_mse_metric.name, 'weighted_out1_mse')
self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
weighted_mae_metric = metric_container.metrics[3]
self.assertEqual(weighted_mae_metric.name, 'weighted_out2_mae')
self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
def test_metric_partial_dict_with_output_names(self):
......@@ -355,14 +363,14 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(a_mae_metric.name, 'a_mae')
self.assertEqual(a_mae_metric.result().numpy(), 1.)
b_1_mse_metric = metric_container.metrics[1]
self.assertEqual(b_1_mse_metric.name, 'b_1_mse')
self.assertEqual(b_1_mse_metric.result().numpy(), 4.)
weighted_a_mae_metric = metric_container.metrics[2]
weighted_a_mae_metric = metric_container.metrics[1]
self.assertEqual(weighted_a_mae_metric.name, 'a_mse')
self.assertEqual(weighted_a_mae_metric.result().numpy(), 1.)
b_1_mse_metric = metric_container.metrics[2]
self.assertEqual(b_1_mse_metric.name, 'b_1_mse')
self.assertEqual(b_1_mse_metric.result().numpy(), 4.)
def test_crossentropy(self):
metric_container = compile_utils.MetricsContainer('crossentropy')
y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
......@@ -422,6 +430,29 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(weighted_mae_metric.name, 'weighted_mae')
self.assertEqual(weighted_mae_metric.result().numpy(), 0.)
def test_broadcast_metrics_to_dict(self):
metric_container = compile_utils.MetricsContainer(metrics=['mae'])
y_p = {'output': ops.convert_to_tensor([[0], [1], [2]])}
y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])}
metric_container.update_state(y_t, y_p)
mae_metric = metric_container.metrics[0]
self.assertEqual(mae_metric.name, 'mae')
self.assertEqual(mae_metric.result().numpy(), 1.)
def test_broadcast_metrics_to_dict_with_output_names(self):
metric_container = compile_utils.MetricsContainer(
metrics=['mae'], output_names=['output'])
y_p = ops.convert_to_tensor([[0], [1], [2]])
y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])}
metric_container.update_state(y_t, y_p)
mae_metric = metric_container.metrics[0]
self.assertEqual(mae_metric.name, 'mae')
self.assertEqual(mae_metric.result().numpy(), 1.)
if __name__ == '__main__':
ops.enable_eager_execution()
......
......@@ -124,11 +124,6 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
def test_iterator_expect_batch_size_numpy(self):
with self.assertRaisesRegexp(
ValueError, r'`batch_size` or `steps` is required'):
self.adapter_cls(self.numpy_input, self.numpy_target)
def test_size_numpy(self):
adapter = self.adapter_cls(
self.numpy_input, self.numpy_target, batch_size=5)
......@@ -428,12 +423,6 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):
self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
def test_iterator_expect_batch_size_generic_arraylike(self):
with self.assertRaisesRegexp(
ValueError, r'`batch_size` or `steps` is required'):
self.adapter_cls(self.arraylike_input,
self.arraylike_target)
def test_size(self):
adapter = self.adapter_cls(
self.arraylike_input,
......@@ -885,6 +874,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
def test_insufficient_data(self):
ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
ds = ds.filter(lambda *args, **kwargs: True)
data_handler = data_adapter.DataHandler(
ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
returned_data = []
......@@ -963,53 +953,6 @@ class DataHandlerTest(keras_parameterized.TestCase):
self.assertEqual(returned_data, [[([0],), ([1],),
([2],)], [([0],), ([1],), ([2],)]])
def test_class_weight(self):
data_handler = data_adapter.DataHandler(
x=[[0], [1], [2]],
y=[[2], [1], [0]],
class_weight={
0: 0.5,
1: 1.,
2: 1.5
},
epochs=2,
steps_per_epoch=3)
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
for _ in data_handler.steps():
epoch_data.append(next(iterator))
returned_data.append(epoch_data)
returned_data = self.evaluate(returned_data)
self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [1.]),
([2], [0], [0.5])],
[([0], [2], [1.5]), ([1], [1], [1.]),
([2], [0], [0.5])]])
def test_class_weight_and_sample_weight(self):
data_handler = data_adapter.DataHandler(
x=[[0], [1], [2]],
y=[[2], [1], [0]],
sample_weight=[[1.], [2.], [4.]],
class_weight={
0: 0.5,
1: 1.,
2: 1.5
},
epochs=2,
steps_per_epoch=3)
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
for _ in data_handler.steps():
epoch_data.append(next(iterator))
returned_data.append(epoch_data)
returned_data = self.evaluate(returned_data)
self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [2.]),
([2], [0], [2.])],
[([0], [2], [1.5]), ([1], [1], [2.]),
([2], [0], [2.])]])
def test_class_weight_user_errors(self):
with self.assertRaisesRegexp(ValueError, 'to be a dict with keys'):
data_adapter.DataHandler(
......
......@@ -40,6 +40,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training_utils
......@@ -50,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
......@@ -200,7 +202,10 @@ class Network(base_layer.Layer):
super(Network, self).__init__(name=name, **kwargs)
self.output_names = None
self.input_names = None
self._is_compiled = False
self._saved_model_inputs_spec = None
# This is True for Sequential networks and Functional networks.
self._compute_output_and_mask_jointly = False
......@@ -326,6 +331,7 @@ class Network(base_layer.Layer):
self._feed_inputs.append(layer.input)
self._compute_tensor_usage_count()
self._set_save_spec(self._nested_inputs)
def _set_output_names(self):
"""Assigns unique names to the Network's outputs.
......@@ -354,8 +360,8 @@ class Network(base_layer.Layer):
self._autocast = kwargs.get('autocast',
base_layer_utils.v2_dtype_behavior_enabled())
self._supports_ragged_inputs = None
self.outputs = []
self.inputs = []
self.outputs = None
self.inputs = None
self.built = False
self._build_input_shape = None
......@@ -573,24 +579,7 @@ class Network(base_layer.Layer):
A list of `InputSpec` instances (one per input to the model)
or a single instance if the model has only one input.
"""
# If subclassed model, can't assume anything.
if not self._is_graph_network:
return None
specs = []
for layer in self._input_layers:
if layer.input_spec is None:
specs.append(None)
else:
if not isinstance(layer.input_spec, list):
raise TypeError('Layer ' + layer.name +
' has an input_spec attribute that '
'is not a list. We expect a list. '
'Found input_spec = ' + str(layer.input_spec))
specs += layer.input_spec
if len(specs) == 1:
return specs[0]
return specs
return
@base_layer_utils.default
def build(self, input_shape):
......@@ -648,6 +637,11 @@ class Network(base_layer.Layer):
if isinstance(input_shape, list):
x = [base_layer_utils.generate_placeholders_from_shape(shape)
for shape in input_shape]
elif isinstance(input_shape, dict):
x = {
k: base_layer_utils.generate_placeholders_from_shape(shape)
for k, shape in input_shape.items()
}
else:
x = base_layer_utils.generate_placeholders_from_shape(input_shape)
......@@ -834,8 +828,7 @@ class Network(base_layer.Layer):
tensor_dict = {}
for x, y in zip(self.inputs, inputs):
x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
# Set shape and dtype based on `keras.Input`s.
if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
try:
y.set_shape(y.shape.merge_with(x.shape))
......@@ -844,6 +837,11 @@ class Network(base_layer.Layer):
'Model was constructed with shape {} for input {}, but it was '
're-called on a Tensor with incompatible shape {}.'
.format(x, x.shape, y.shape))
if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)):
y = math_ops.cast(y, dtype=x.dtype)
x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
depth_keys = list(self._nodes_by_depth.keys())
depth_keys.sort(reverse=True)
......@@ -1533,6 +1531,32 @@ class Network(base_layer.Layer):
new_layers.append(add_metric_layer)
self._insert_layers(new_layers, new_nodes)
@trackable.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
if self._saved_model_inputs_spec is not None:
return # Already set.
input_names = self.input_names
if not input_names:
input_names = compile_utils.create_pseudo_input_names(inputs)
flat_inputs = nest.flatten(inputs)
specs = []
for name, tensor in zip(input_names, flat_inputs):
specs.append(
tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
specs = nest.pack_sequence_as(inputs, specs)
self._saved_model_inputs_spec = specs
def _get_save_spec(self, dynamic_batch=True):
if self._saved_model_inputs_spec is None:
return None
return nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_inputs_spec)
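A hedged sketch of the new save-spec helpers. `_set_save_spec` is normally invoked from `_init_graph_network` / `_set_inputs` as this diff shows; calling it by hand here is purely illustrative, and exact behavior on other Keras versions may differ:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model._set_save_spec({'feature': tf.zeros((2, 3))})
spec = model._get_save_spec(dynamic_batch=True)
# `spec` is expected to be a dict along the lines of
# {'feature': tf.TensorSpec(shape=(None, 3), dtype=tf.float32)}.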
@property
def _trackable_saved_model_saver(self):
return network_serialization.NetworkSavedModelSaver(self)
......
......@@ -266,6 +266,10 @@ class Sequential(training.Model):
self.built = True
def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name
if self._build_input_shape is None:
input_shapes = nest.map_structure(_get_shape_tuple, inputs)
self._build_input_shape = input_shapes
if self._is_graph_network:
if not self.built:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
......@@ -364,7 +368,7 @@ class Sequential(training.Model):
'name': self.name,
'layers': copy.deepcopy(layer_configs)
}
if self._build_input_shape:
if self._build_input_shape is not None:
config['build_input_shape'] = self._build_input_shape
return config
......@@ -383,7 +387,8 @@ class Sequential(training.Model):
layer = layer_module.deserialize(layer_config,
custom_objects=custom_objects)
model.add(layer)
if not model.inputs and build_input_shape:
if (not model.inputs and build_input_shape and
isinstance(build_input_shape, (tuple, list))):
model.build(build_input_shape)
return model
......@@ -396,3 +401,12 @@ class Sequential(training.Model):
@property
def _trackable_saved_model_saver(self):
return model_serialization.SequentialSavedModelSaver(self)
def _get_shape_tuple(t):
if hasattr(t, 'shape'):
shape = t.shape
if shape.rank is not None:
return tuple(shape.as_list())
return None
return None
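An illustrative sketch of `_get_shape_tuple`, assuming the helper is importable from the private `sequential` module:

import tensorflow as tf

from tensorflow.python.keras.engine import sequential

print(sequential._get_shape_tuple(tf.zeros((2, 3))))            # (2, 3)
print(sequential._get_shape_tuple(tf.keras.Input(shape=(4,))))  # (None, 4)
print(sequential._get_shape_tuple(object()))                    # None, no shape attribute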
......@@ -286,9 +286,16 @@ class TestSequential(keras_parameterized.TestCase):
self.assertTrue(model.built)
config = model.get_config()
self.assertIn('build_input_shape', config)
new_model = keras.models.Sequential.from_config(config)
new_model.compile(
loss='mse',
optimizer='rmsprop',
metrics=[keras.metrics.CategoricalAccuracy()],
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
x = np.random.random((batch_size, input_dim))
y = np.random.random((batch_size, num_classes))
new_model.train_on_batch(x, y)
self.assertEqual(len(new_model.layers), 2)
self.assertEqual(len(new_model.weights), 4)
......@@ -321,15 +328,12 @@ class TestSequential(keras_parameterized.TestCase):
self.assertFalse(model.built)
model(array_ops.zeros([1, 2]))
self.assertTrue(model.built)
self.assertEqual(len(model.outputs), 0)
model.compile(
'rmsprop',
loss='mse',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
self.assertEqual(len(model.outputs), 0)
model.train_on_batch(np.zeros((1, 2)), np.zeros((1, 5)))
self.assertEqual(len(model.outputs), 1)
@keras_parameterized.run_all_keras_modes
def test_sequential_nesting(self):
......@@ -399,29 +403,21 @@ class TestSequential(keras_parameterized.TestCase):
ValueError, 'should have a single output tensor'):
keras.Sequential([MultiOutputLayer()])(np.zeros((10, 10)))
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_layer_add_after_compile_deferred(self):
model = keras.Sequential([keras.layers.Dense(3)])
self.assertFalse(model.built)
self.assertFalse(model.inputs)
self.assertFalse(model.outputs)
model.compile('adam', loss='mse')
model.fit(np.random.random((1, 3)), np.random.random((1, 3)))
self.assertTrue(model.built)
self.assertTrue(model.inputs)
self.assertTrue(model.outputs)
model.add(keras.layers.Dense(3))
self.assertTrue(model.built)
self.assertTrue(model.inputs)
self.assertTrue(model.outputs)
self.assertFalse(model.built)
model.compile('adam', loss='mse')
model.fit(np.random.random((1, 3)), np.random.random((1, 3)))
self.assertTrue(model.built)
def test_sequential_layer_tracking(self):
"""Test that Sequential only tracks layers added in init or `.add`."""
......@@ -442,21 +438,6 @@ class TestSequential(keras_parameterized.TestCase):
model.pop()
self.assertEqual(model._layers[-1], layer)
@testing_utils.enable_v2_dtype_behavior
def test_sequential_does_not_autocast(self):
class AssertFloat64InputLayer(keras.layers.Layer):
def __init__(self):
super(AssertFloat64InputLayer, self).__init__(autocast=False)
def call(self, inputs):
assert inputs.dtype == 'float64', 'inputs are %s' % inputs.dtype
return array_ops.identity(inputs)
model = keras.Sequential([AssertFloat64InputLayer(), keras.layers.Dense(4)])
model(np.random.random((4, 4)))
class TestSequentialEagerIntegration(keras_parameterized.TestCase):
......@@ -500,27 +481,6 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):
y = np.random.random((2, 5))
model.fit(x, y, epochs=1)
@keras_parameterized.run_all_keras_modes
def test_sequential_model_fails_with_dict_inputs(self):
num_classes = 5
model = testing_utils.get_small_sequential_mlp(
num_hidden=10, num_classes=num_classes)
model.compile(
'rmsprop',
metrics=['acc'],
weighted_metrics=['mae'],
loss='categorical_crossentropy',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
x = {'dense_input': np.random.random((10, 1))}
y = np.random.randint(num_classes, size=(10, 1))
with self.assertRaisesRegexp(
ValueError, 'Passing a dictionary input to a Sequential Model which '
'doesn\'t have FeatureLayer as the first layer is an error'):
model.fit(x, y, batch_size=5, epochs=1)
if __name__ == '__main__':
test.main()
......@@ -226,13 +226,9 @@ def model_iteration(model,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
samples=num_samples_or_steps,
verbose=0, # Handle ProgBarLogger separately in this loop.
count_mode=count_mode,
verbose=verbose,
mode=mode)
# TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
progbar = training_utils.get_progbar(
model, count_mode, mode != ModeKeys.PREDICT)
progbar.params = callbacks.params
progbar.params['verbose'] = verbose
# Find beforehand arrays that need sparse-to-dense conversion.
if issparse is not None and not use_steps:
......@@ -259,7 +255,6 @@ def model_iteration(model,
callbacks.model.stop_training = False
callbacks._call_begin_hook(mode)
progbar.on_train_begin()
initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)
......@@ -275,7 +270,6 @@ def model_iteration(model,
model.reset_metrics()
if mode == ModeKeys.TRAIN:
callbacks.on_epoch_begin(epoch, epoch_logs)
progbar.on_epoch_begin(epoch, epoch_logs)
if use_steps:
# Step-wise loop.
......@@ -290,7 +284,6 @@ def model_iteration(model,
while step < target_steps:
batch_logs = {'batch': step, 'size': 1}
callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
progbar.on_batch_begin(step, batch_logs)
# Get outputs.
try:
......@@ -320,9 +313,6 @@ def model_iteration(model,
elif step > 0:
steps_per_epoch = step
aggregator.steps = steps_per_epoch
if mode == ModeKeys.TRAIN:
progbar.params['steps'] = steps_per_epoch
progbar.progbar.target = steps_per_epoch
else:
# We ran out of batches while the user passed an iterator (legacy).
callbacks.model.stop_training = True
......@@ -350,7 +340,6 @@ def model_iteration(model,
# Callbacks batch end.
batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
callbacks._call_batch_hook(mode, 'end', step, batch_logs)
progbar.on_batch_end(step, batch_logs)
step += 1
if callbacks.model.stop_training:
......@@ -392,7 +381,6 @@ def model_iteration(model,
# Callbacks batch_begin.
batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)
progbar.on_batch_begin(batch_index, batch_logs)
# Get outputs.
batch_outs = f(ins_batch)
......@@ -407,7 +395,6 @@ def model_iteration(model,
# Callbacks batch end.
batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)
progbar.on_batch_end(batch_index, batch_logs)
if callbacks.model.stop_training:
break
......@@ -452,7 +439,6 @@ def model_iteration(model,
if mode == ModeKeys.TRAIN:
# Epochs only apply to `fit`.
callbacks.on_epoch_end(epoch, epoch_logs)
progbar.on_epoch_end(epoch, epoch_logs)
# Reinitialize dataset iterator for the next epoch.
if reset_dataset_after_each_epoch and epoch < epochs - 1:
......
......@@ -107,8 +107,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
validation_data=dataset, validation_steps=2)
# Test with validation split
with self.assertRaisesRegexp(
ValueError, '`validation_split` argument is not supported when '):
with self.assertRaises(ValueError):
model.fit(dataset,
epochs=1, steps_per_epoch=2, verbose=0,
validation_split=0.5, validation_steps=2)
......@@ -124,19 +123,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
verbose=0,
sample_weight=sample_weight)
# Test invalid usage
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.fit(dataset, batch_size=10, epochs=1, steps_per_epoch=2,
verbose=0)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.predict(dataset, batch_size=10, steps=2, verbose=0)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.evaluate(dataset, batch_size=10, steps=2, verbose=0)
with self.assertRaisesRegexp(
ValueError, '(you should not specify a target)|'
'(`y` argument is not supported when using dataset as input.)'):
......@@ -144,14 +130,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
epochs=1, steps_per_epoch=2, verbose=0)
# With an infinite dataset, `steps_per_epoch`/`steps` argument is required.
with self.assertRaisesRegexp(
ValueError, 'the `steps_per_epoch` argument'):
with self.assertRaises(ValueError):
model.fit(dataset, epochs=1, verbose=0)
with self.assertRaisesRegexp(ValueError,
'the `steps` argument'):
with self.assertRaises(ValueError):
model.evaluate(dataset, verbose=0)
with self.assertRaisesRegexp(ValueError,
'the `steps` argument'):
with self.assertRaises(ValueError):
model.predict(dataset, verbose=0)
@keras_parameterized.run_with_all_model_types(exclude_models='sequential')
......@@ -185,14 +168,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset_tuple, steps=2, verbose=1)
predict_dataset_tuple = dataset_ops.Dataset.from_tensor_slices(
(input_a_np, input_b_np))
# TODO(b/123360757): Remove below assertion once predict() supports
# muti-input datasets.
with self.assertRaisesRegexp(ValueError,
'Error when checking model input'):
model.predict(predict_dataset_tuple, steps=1)
# Test with dict
input_dict = {'input_1': input_a_np, 'input_2': input_b_np}
if testing_utils.get_model_type() == 'subclass':
......@@ -457,15 +432,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
self.assertIn('10/10', lines[-1])
self.assertLen(history.history['loss'], 2)
# The first epoch will invoke batch begin 11 times, since it doesn't know
# the cardinality. The second epoch should just invoke 10 times.
if (testing_utils.should_run_eagerly()
or testing_utils.should_run_tf_function()):
expected_batch_begin_count = 21
else:
expected_batch_begin_count = 20
self.assertEqual(batch_counter.batch_begin_count,
expected_batch_begin_count)
self.assertEqual(batch_counter.batch_begin_count, 21)
self.assertEqual(batch_counter.batch_end_count, 20)
model.evaluate(dataset)
out = model.predict(dataset)
......
......@@ -194,12 +194,10 @@ class TrainingTest(keras_parameterized.TestCase):
model.fit(dataset, epochs=1, verbose=0)
# Step argument is required for infinite datasets.
with self.assertRaisesRegexp(ValueError,
'specify the `validation_steps` argument.'):
with self.assertRaises(ValueError):
model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=validation_dataset)
with self.assertRaisesRegexp(ValueError,
'specify the `validation_steps` argument.'):
with self.assertRaises(ValueError):
model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=validation_dataset)
......@@ -355,7 +353,8 @@ class CorrectnessTest(keras_parameterized.TestCase):
x = np.ones((20, 4)).astype(np.float32)
y = np.random.randint(0, 3, size=(20,)).astype(np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
evaluation_results = dict(zip(model.metrics_names, model.evaluate(dataset)))
results = model.evaluate(dataset)
evaluation_results = dict(zip(model.metrics_names, results))
# Rate of dropout depends on the learning phase.
self.assertEqual(evaluation_results['regularization_loss'],
expected_validation_loss)
......
......@@ -174,12 +174,9 @@ def model_iteration(model,
steps_per_epoch=steps_per_epoch,
batch_size=batch_size,
samples=num_samples_or_steps,
verbose=0, # Handle ProgBar as part of Callbacks once hooks are ready.
count_mode=count_mode,
verbose=verbose,
mode=mode)
# TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
progbar = training_utils.get_progbar(model, count_mode)
progbar.params = callbacks.params
progbar.params['verbose'] = verbose
if mode == ModeKeys.PREDICT:
aggregator = training_utils.OutputsAggregator(True, steps=steps_per_epoch)
......@@ -194,7 +191,6 @@ def model_iteration(model,
callbacks.model.stop_training = False
callbacks._call_begin_hook(mode)
progbar.on_train_begin()
initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)
......@@ -207,7 +203,6 @@ def model_iteration(model,
epoch_logs = {}
if mode == ModeKeys.TRAIN:
callbacks.on_epoch_begin(epoch, epoch_logs)
progbar.on_epoch_begin(epoch, epoch_logs)
if steps_per_epoch is None:
# Loop over dataset until `OutOfRangeError` is raised.
......@@ -237,9 +232,6 @@ def model_iteration(model,
elif step > 0:
steps_per_epoch = step
aggregator.steps = steps_per_epoch
if mode == ModeKeys.TRAIN:
progbar.params['steps'] = steps_per_epoch
progbar.progbar.target = steps_per_epoch
else:
# We ran out of batches while the user passed an iterator (legacy).
callbacks.model.stop_training = True
......@@ -259,7 +251,6 @@ def model_iteration(model,
# Callbacks batch begin.
batch_logs = {'batch': step, 'size': batch_size}
callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
progbar.on_batch_begin(step, batch_logs)
is_deferred = not model._is_compiled
batch_outs = batch_function(*batch_data)
......@@ -283,16 +274,12 @@ def model_iteration(model,
verbose=verbose,
mode=mode)
progbar.params = callbacks.params
progbar.params['verbose'] = verbose
# Aggregate results.
aggregator.aggregate(batch_outs)
# Callbacks batch end.
batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
callbacks._call_batch_hook(mode, 'end', step, batch_logs)
progbar.on_batch_end(step, batch_logs)
step += 1
if callbacks.model.stop_training:
......@@ -330,7 +317,6 @@ def model_iteration(model,
if mode == ModeKeys.TRAIN:
# Epochs only apply to `fit`.
callbacks.on_epoch_end(epoch, epoch_logs)
progbar.on_epoch_end(epoch, epoch_logs)
# Recreate dataset iterator for the next epoch.
if reset_dataset_after_each_epoch and epoch < epochs - 1:
......
......@@ -245,15 +245,14 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
err_msg = 'Output of generator should be a tuple of 1 or 2 or 3 elements'
with self.assertRaisesRegex(ValueError, err_msg):
with self.assertRaises(ValueError):
model.fit_generator(invalid_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaisesRegex(ValueError, err_msg):
with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
......@@ -262,12 +261,12 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
use_multiprocessing=False,
validation_data=invalid_generator(),
validation_steps=10)
with self.assertRaisesRegex(ValueError, err_msg):
with self.assertRaises(ValueError):
model.predict_generator(invalid_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaisesRegex(ValueError, err_msg):
with self.assertRaises(ValueError):
model.evaluate_generator(invalid_generator(),
steps=5,
max_queue_size=10,
......@@ -330,38 +329,11 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
model.evaluate(custom_generator_changing_batch_size(), steps=5)
model.predict(custom_generator_changing_batch_size(), steps=5)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_invalid_batch_size_argument(self):
def ones_generator():
while True:
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
model = testing_utils.get_small_mlp(
num_hidden=10, num_classes=1, input_dim=10)
model.compile(
'adam',
'binary_crossentropy',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.fit(ones_generator(), batch_size=2, epochs=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.evaluate(ones_generator(), batch_size=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.predict(ones_generator(), batch_size=2)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
@data_utils.dont_use_multiprocessing_pool
def test_generator_dynamic_shapes(self):
x = [
'I think juice is great',
'unknown is the best language since slicedbread',
......
......@@ -49,8 +49,6 @@ from tensorflow.python.keras.engine import training_distributed
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_v2
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving.saved_model import model_serialization
......@@ -162,6 +160,8 @@ class Model(training_lib.Model):
self._experimental_run_tf_function = (
ops.executing_eagerly_outside_functions())
self._v1_compile_was_called = False
@trackable.no_automatic_dependency_tracking
def _set_strategy(self, strategy):
self._compile_time_distribution_strategy = strategy
......@@ -301,6 +301,7 @@ class Model(training_lib.Model):
self._run_eagerly = kwargs.pop('run_eagerly', None)
self._experimental_run_tf_function = kwargs.pop(
'experimental_run_tf_function', True)
self._v1_compile_was_called = True
# Prepare Session arguments (legacy).
kwargs.pop('cloning', None) # Legacy DistStrat argument, never used.
......@@ -561,14 +562,6 @@ class Model(training_lib.Model):
'original `Dataset` object instead of passing in '
'`iter(dataset)`.')
# Experiment training loop with default DS path.
if context.executing_eagerly() and self._experimental_run_tf_function:
if self._in_multi_worker_mode():
return training_distributed.DistributionMultiWorkerTrainingLoop(
training_v2.Loop())
else:
return training_v2.Loop()
# Case 1: distribution strategy.
if self._distribution_strategy:
if self._in_multi_worker_mode():
......@@ -1031,18 +1024,6 @@ class Model(training_lib.Model):
"""
self._assert_compile_was_called()
self._check_call_args('train_on_batch')
if self._experimental_run_tf_function:
outputs = training_v2_utils.train_on_batch(
self, x, y=y, sample_weight=sample_weight,
class_weight=class_weight, reset_metrics=reset_metrics,
standalone=True)
outputs = (outputs['total_loss'] + outputs['output_losses'] +
outputs['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
if len(outputs) == 1:
outputs = outputs[0]
return outputs
# If at this point we are in the replica context, then it is okay to execute
# the Eager code path. The expected way to get here is to call `fit` that
......@@ -1069,8 +1050,7 @@ class Model(training_lib.Model):
output_loss_metrics=self._output_loss_metrics)
outputs = (output_dict['total_loss'] + output_dict['output_losses']
+ output_dict['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
x = training_utils.ModelInputs(x).as_list()
ins = x + list(y or []) + list(sample_weights or [])
......@@ -1129,17 +1109,6 @@ class Model(training_lib.Model):
"""
self._assert_compile_was_called()
self._check_call_args('test_on_batch')
if self._experimental_run_tf_function:
outputs = training_v2_utils.test_on_batch(
self, x, y=y, sample_weight=sample_weight,
reset_metrics=reset_metrics, standalone=True)
outputs = (outputs['total_loss'] + outputs['output_losses'] +
outputs['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
if len(outputs) == 1:
outputs = outputs[0]
return outputs
if (self._distribution_strategy and
distribution_strategy_context.in_cross_replica_context()):
......@@ -1160,8 +1129,7 @@ class Model(training_lib.Model):
output_loss_metrics=self._output_loss_metrics)
outputs = (output_dict['total_loss'] + output_dict['output_losses']
+ output_dict['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
x = training_utils.ModelInputs(x).as_list()
inputs = x + list(y or []) + list(sample_weights or [])
......@@ -1196,8 +1164,6 @@ class Model(training_lib.Model):
expectations of the model.
"""
self._check_call_args('predict_on_batch')
if self._experimental_run_tf_function:
return training_v2_utils.predict_on_batch(self, x, standalone=True)
if (self._distribution_strategy and
distribution_strategy_context.in_cross_replica_context()):
......@@ -2601,6 +2567,7 @@ class Model(training_lib.Model):
ValueError: If dict inputs are passed to a Sequential Model where the
first layer isn't FeatureLayer.
"""
self._set_save_spec(inputs)
inputs = self._set_input_attrs(inputs)
if outputs is None:
......@@ -2760,7 +2727,7 @@ class Model(training_lib.Model):
training setting, return the epoch the training is supposed to continue
at. Otherwise, return the `initial_epoch` the user passes in.
"""
if hasattr(self, '_training_state'):
if self._training_state is not None:
return self._training_state.maybe_load_initial_epoch_from_ckpt(
initial_epoch, mode)
return initial_epoch
......@@ -2781,7 +2748,7 @@ class Model(training_lib.Model):
# then the optimizer is set. This is different from whether the
# model is compiled
# (i.e. whether the model is built and its inputs/outputs are set).
if not self.optimizer:
if not self._compile_was_called:
raise RuntimeError('You must compile your model before '
'training/testing. '
'Use `model.compile(optimizer, loss)`.')
......@@ -2821,6 +2788,21 @@ class Model(training_lib.Model):
def _trackable_saved_model_saver(self):
return model_serialization.ModelSavedModelSaver(self)
def _get_compile_args(self):
self._assert_compile_was_called()
kwargs = {
'loss': self.loss,
'metrics': self._compile_metrics,
'loss_weights': self.loss_weights,
'sample_weight_mode': self.sample_weight_mode,
'weighted_metrics': self._compile_weighted_metrics,
}
return kwargs
@property
def _compile_was_called(self):
return self._v1_compile_was_called
class DistributedCallbackModel(Model):
"""Model that is used for callbacks with tf.distribute.Strategy."""
......@@ -3189,3 +3171,8 @@ def _get_metrics_from_layers(layers):
else:
metrics.extend(layer.metrics)
return metrics
def _non_none_constant_value(v):
constant_value = tensor_util.constant_value(v)
return constant_value if constant_value is not None else v
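A minimal sketch of what `_non_none_constant_value` does, restated standalone for clarity:

import tensorflow as tf
from tensorflow.python.framework import tensor_util


def _non_none_constant_value(v):
  # Fold tensors whose value is statically known back to NumPy; otherwise
  # return the tensor unchanged.
  constant_value = tensor_util.constant_value(v)
  return constant_value if constant_value is not None else v


print(_non_none_constant_value(tf.constant(3.0)))  # 3.0 as a NumPy scalar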
This diff has been collapsed. (5 additional files not shown.)