Commit a3777018 authored by Philip Pham, committed by TensorFlower Gardener

Retrace Keras history in `add_loss` and `add_metric`

In functional graph network models, custom losses and metrics can end up
disconnected from the network. We can retrace the graph and insert the
additional ancillary layers to support this use case.

This lets the user define custom losses and metrics computed from layer outputs.

Fixes #30378.

PiperOrigin-RevId: 258843847
Parent 827c1005
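
For illustration (not part of this commit), here is a minimal sketch of the usage pattern this change supports, assuming the tf.keras API of this era; the model and names below are hypothetical:

import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(10,), name='features')
x = keras.layers.Dense(16, activation='relu', name='embedding')(inputs)
outputs = keras.layers.Dense(1, name='logits')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# The loss and metric below are built from raw TensorFlow ops on an
# intermediate tensor. The ops are not Keras layers, so Keras must retrace
# their history and insert ancillary layers to connect them to the network.
l2 = tf.reduce_mean(tf.reduce_sum(tf.square(x), axis=-1))
model.add_loss(0.01 * l2)
model.add_metric(l2, aggregation='mean', name='l2_activation')

model.compile(optimizer='adam', loss='mse')

This mirrors the pattern exercised by the new tests in the diff below.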
@@ -1397,10 +1397,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
model.evaluate(inputs, targets)
@combinations.generate(
combinations.times(
all_strategy_minus_default_and_tpu_combinations() +
tpu_strategy_combinations(),
combinations.combine(cloning=[True, False])))
combinations.times(all_strategy_combinations_minus_default(),
combinations.combine(cloning=[True, False])))
def test_distribution_strategy_one_dimensional(self, distribution, cloning):
with distribution.scope():
inp = keras.layers.Input(shape=(10,))
@@ -1464,7 +1462,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
1e-5)
@combinations.generate(
combinations.times(all_strategy_minus_default_and_tpu_combinations(),
combinations.times(all_strategy_combinations_minus_default(),
combinations.combine(cloning=[True, False])))
def test_distribution_strategy_with_symbolic_add_loss(self, distribution,
cloning):
@@ -1636,6 +1634,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history)
@combinations.generate(
# TODO(phillypham): Why does validation_steps > 1 not work on TPUs?
combinations.times(all_strategy_minus_default_and_tpu_combinations(),
combinations.combine(cloning=[True, False])))
def test_distribution_strategy_with_add_metric_outside_call(
@@ -1680,7 +1679,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
@combinations.generate(
combinations.combine(
distribution=strategies_minus_default_minus_tpu, mode=['eager']))
distribution=strategies_minus_default_minus_tpu + tpu_strategies,
mode=['eager']))
def test_correctness_of_add_loss_with_merge_call(self, distribution):
batch_size = 32
@@ -1745,5 +1745,110 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
train_step(x)
# Models to exercise inserting ancillary layers with add_loss and add_metric.
def _functional_with_add_loss_and_metric(input_shape, num_classes, l1, l2):
inputs = keras.Input(input_shape, name='images')
x = keras.layers.Conv2D(32, kernel_size=5, activation='relu')(inputs)
x = keras.layers.MaxPooling2D(pool_size=2)(x)
x = keras.layers.Conv2D(64, kernel_size=5, activation='relu')(x)
x = keras.layers.MaxPooling2D(pool_size=2)(x)
# Apply L2 regularization to embedding. Use a mix of TensorFlow ops and layers
# to exercise all code paths.
x = keras.layers.Flatten(name='embedding')(x)
l2_loss = math_ops.reduce_mean(math_ops.reduce_sum(math_ops.square(x), -1))
# Apply L1 regularization to next layer.
x = keras.layers.Dense(1024, activation='relu', name='sparse_embedding')(x)
l1_loss = keras.layers.Lambda(
lambda x: math_ops.reduce_mean(math_ops.reduce_sum(x, -1)),
name='l1_loss')(
x)
outputs = keras.layers.Dense(num_classes, name='logits')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
# Weight regularization terms.
model.add_loss(keras.layers.Lambda(lambda x: x * l2)(l2_loss))
model.add_metric(l2_loss, aggregation='mean', name='l2_loss')
model.add_loss(l1_loss * l1)
model.add_metric(l1_loss, aggregation='mean', name='l1_loss')
return model
def _sequential_with_add_loss_and_metric(input_shape, num_classes, l1, l2):
model = keras.Sequential([
keras.layers.Conv2D(
32, kernel_size=5, activation='relu', input_shape=input_shape),
keras.layers.MaxPooling2D(pool_size=2),
keras.layers.Conv2D(64, kernel_size=5, activation='relu'),
keras.layers.MaxPooling2D(pool_size=2),
keras.layers.Flatten(name='embedding'),
keras.layers.Dense(1024, activation='relu', name='sparse_embedding'),
keras.layers.Dense(num_classes, name='logits'),
])
# Extract layer outputs, add regularization terms, and rescale the metric.
# Use a mix of TensorFlow ops and layers to exercise all code paths.
x = model.get_layer('sparse_embedding').get_output_at(-1)
l1_loss = l1 * math_ops.reduce_mean(math_ops.reduce_sum(x, -1))
model.add_loss(l1_loss)
model.add_metric(
keras.layers.Lambda(lambda x: math_ops.divide(x, l1))(l1_loss),
aggregation='mean',
name='l1_loss')
x = model.get_layer('embedding').get_output_at(-1)
l2_loss = keras.layers.Lambda(
lambda x: l2 * math_ops.reduce_mean(math_ops.reduce_sum(x * x, -1)),
name='l2_loss')(
x)
model.add_loss(l2_loss)
model.add_metric(l2_loss / l2, aggregation='mean', name='l2_loss')
return model
class TestDistributionStrategyWithMultipleAddLossAndMetricCalls(
test.TestCase, parameterized.TestCase):
"""Tests complex models with multiple add loss and metric calls."""
@combinations.generate(
combinations.times(
all_strategy_combinations_minus_default(),
combinations.combine(
model_fn=[
_functional_with_add_loss_and_metric,
_sequential_with_add_loss_and_metric,
],
l1=[0.01],
l2=[0.1])))
def test_fit_and_evaluate(self, distribution, model_fn, l1, l2):
# Make fake MNIST-like image data.
dataset = dataset_ops.DatasetV2.from_tensor_slices(
(np.random.uniform(size=(64, 28, 28, 1)).astype(np.float32),
np.random.randint(0, 10, size=(64,))))
dataset = dataset.shuffle(64).batch(
8 * distribution.num_replicas_in_sync, drop_remainder=True)
# Make model with distribution strategy and initialize with dataset shape.
input_shape = dataset_ops.get_structure(dataset)[0].shape[1:]
with distribution.scope():
model = model_fn(input_shape, 10, l1, l2)
model.compile(
optimizer=keras.optimizers.adam_v2.Adam(1e-4),
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE),
metrics=[
keras.metrics.SparseCategoricalAccuracy(),
keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
])
# Non-eager training doesn't support steps_per_epoch=None.
for unused_epoch in range(2):
model.fit(dataset)
results = dict(zip(model.metrics_names, model.evaluate(dataset)))
# Sanity checks.
self.assertBetween(results['sparse_categorical_accuracy'], 0.05, 1.)
self.assertGreater(results['l2_loss'], 0.)
self.assertGreater(results['l1_loss'], 0.)
# Assert correctness of the loss calculation and updating of metrics.
self.assertNear(
results['l1_loss'] * l1 + results['l2_loss'] * l2 +
results['sparse_categorical_crossentropy'], results['loss'], 1e-6)
if __name__ == '__main__':
test.main()
@@ -999,13 +999,7 @@ class Layer(module.Module):
else:
for symbolic_loss in symbolic_losses:
if getattr(self, '_is_graph_network', False):
new_layers = base_layer_utils.create_keras_history(symbolic_loss)
# Losses must be keyed on inputs no matter what in order to
# be supported in DistributionStrategy.
add_loss_layer = AddLoss(unconditional=False)
add_loss_layer(symbolic_loss)
new_layers.append(add_loss_layer)
self._insert_layers(new_layers)
self._graph_network_add_loss(symbolic_loss)
else:
# Possible a loss was added in a Layer's `build`.
self._losses.append(symbolic_loss)
@@ -1092,11 +1086,7 @@ class Layer(module.Module):
'Tensor to monitor directly.')
# Insert layers into the Keras Graph Network.
new_layers = base_layer_utils.create_keras_history(value)
add_metric_layer = AddMetric(aggregation, name)
add_metric_layer(value)
new_layers.append(add_metric_layer)
self._insert_layers(new_layers)
self._graph_network_add_metric(value, aggregation, name)
@deprecation.deprecated_args(None, '`inputs` is now automatically inferred',
'inputs')
......
@@ -1658,6 +1658,22 @@ class Network(base_layer.Layer):
def _object_identifier(self):
return '_tf_keras_network'
def _graph_network_add_loss(self, symbolic_loss):
new_layers = _diff_layers(self.inputs, [symbolic_loss], self._layers)
# Losses must be keyed on inputs no matter what in order to be supported in
# DistributionStrategy.
add_loss_layer = base_layer.AddLoss(unconditional=False)
add_loss_layer(symbolic_loss)
new_layers.append(add_loss_layer)
self._insert_layers(new_layers)
def _graph_network_add_metric(self, value, aggregation, name):
new_layers = _diff_layers(self.inputs, [value], self._layers)
add_metric_layer = base_layer.AddMetric(aggregation, name)
add_metric_layer(value)
new_layers.append(add_metric_layer)
self._insert_layers(new_layers)
def _is_hdf5_filepath(filepath):
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
@@ -1850,3 +1866,20 @@ def _map_graph_network(inputs, outputs):
str(all_names.count(name)) + ' times in the model. '
'All layer names should be unique.')
return network_nodes, nodes_by_depth, layers, layers_by_depth
def _diff_layers(inputs, outputs, layers):
"""Returns the layers in the network topology minus those in `layers`.
Args:
inputs: List of input tensors.
outputs: List of output tensors.
layers: List of layers.
Returns:
List of layers in the network topology not in `layers`.
"""
base_layer_utils.create_keras_history(outputs)
# List of all layers in the topology between inputs and outputs.
all_layers = _map_graph_network(inputs, outputs)[2]
return [layer for layer in all_layers if layer not in layers]
@@ -25,6 +25,7 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.keras import saving
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
@@ -54,6 +55,77 @@ def _clone_layer(layer):
return layer.__class__.from_config(layer.get_config())
def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
"""Inserts ancillary layers into the model with the proper order."""
# Sort `AddMetric` layers so they agree with metrics_names.
metric_layers = [
layer for layer in ancillary_layers if isinstance(layer, AddMetric)
]
metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
ancillary_layers = [
layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
] + metric_layers
nodes = set(
nest.flatten([layer._inbound_nodes for layer in ancillary_layers]))
relevant_nodes = list(nodes.intersection(new_nodes))
model._insert_layers(ancillary_layers, relevant_nodes=relevant_nodes)
def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
"""Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.
Args:
nodes_by_depth: Provides structure information to create new nodes.
layer_fn: Function to clone layers.
layer_map: Map from layers in `model` to new layers.
tensor_map: Map from tensors in `model` to newly computed tensors.
Returns:
A set of new nodes. `layer_map` and `tensor_map` are updated.
"""
# Iterate over every node in the reference model, in depth order.
new_nodes = set()
depth_keys = list(nodes_by_depth.keys())
depth_keys.sort(reverse=True)
for depth in depth_keys:
nodes = nodes_by_depth[depth]
for node in nodes:
# Recover the corresponding layer.
layer = node.outbound_layer
# Get or create layer.
if layer not in layer_map:
new_layer = layer_fn(layer)
layer_map[layer] = new_layer
layer = new_layer
else:
# Reuse previously cloned layer.
layer = layer_map[layer]
# Don't call InputLayer multiple times.
if isinstance(layer, InputLayer):
continue
# If all previous input tensors are available in tensor_map,
# then call node.inbound_layer on them.
if all(
tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
computed_tensors = nest.map_structure(lambda t: tensor_map[t],
node.input_tensors)
# Call layer.
kwargs = node.arguments or {}
output_tensors = layer(computed_tensors, **kwargs)
# Thread-safe way to keep track of what node was created.
first_output_tensor = nest.flatten(output_tensors)[0]
new_nodes.add(
layer._inbound_nodes[first_output_tensor._keras_history.node_index])
for x, y in zip(
nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
tensor_map[x] = y
return new_nodes
def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
"""Clone a functional `Model` instance.
@@ -137,48 +209,9 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
if not callable(layer_fn):
raise ValueError('Expected `layer_fn` argument to be a callable.')
new_nodes = set()
# Iterated over every node in the reference model, in depth order.
depth_keys = list(model._nodes_by_depth.keys())
depth_keys.sort(reverse=True)
for depth in depth_keys:
nodes = model._nodes_by_depth[depth]
for node in nodes:
# Recover the corresponding layer.
layer = node.outbound_layer
# Get or create layer.
if layer not in layer_map:
new_layer = layer_fn(layer)
layer_map[layer] = new_layer
layer = new_layer
else:
# Reuse previously cloned layer.
layer = layer_map[layer]
# Don't call InputLayer multiple times.
if isinstance(layer, InputLayer):
continue
# If all previous input tensors are available in tensor_map,
# then call node.inbound_layer on them.
if all(
tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
computed_tensors = nest.map_structure(lambda t: tensor_map[t],
node.input_tensors)
# Call layer.
kwargs = node.arguments or {}
output_tensors = layer(computed_tensors, **kwargs)
# Thread-safe way to keep track of what node was created.
first_output_tensor = nest.flatten(output_tensors)[0]
new_nodes.add(
layer._inbound_nodes[first_output_tensor._keras_history.node_index])
for x, y in zip(
nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
tensor_map[x] = y
# Has the side effect of filling out `layer_map` and `tensor_map`.
new_nodes = _make_new_nodes(model._nodes_by_depth, layer_fn, layer_map,
tensor_map)
# Check that we did compute the model outputs,
# then instantiate a new model from inputs and outputs.
output_tensors = []
@@ -188,20 +221,47 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors)
output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors)
metrics_names = model.metrics_names
model = Model(input_tensors, output_tensors, name=model.name)
# Layers not directly tied to outputs of the Model, such as loss layers
# created in `add_loss`.
# created in `add_loss` and `add_metric`.
ancillary_layers = [
layer for layer in layer_map.values() if layer not in model.layers
]
if ancillary_layers:
nodes = set(
nest.flatten([layer._inbound_nodes for layer in ancillary_layers]))
relevant_nodes = list(nodes.intersection(new_nodes))
model._insert_layers(ancillary_layers, relevant_nodes=relevant_nodes)
_insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
return model
def _remove_ancillary_layers(model, layer_map, layers):
"""Removes and returns any ancillary layers from `layers` based on `model`.
Ancillary layers are part of the model topology but not used to compute the
model outputs, e.g., layers from `add_loss` and `add_metric`.
Args:
model: A Keras Model.
layer_map: A map from layers in `model` to those in `layers`.
layers: A list of all layers.
Returns:
Two lists of layers: (1) `layers` with the ancillary layers removed, and (2)
the ancillary layers.
"""
ancillary_layers = [] # Additional layers for computing losses and metrics.
if not model._is_graph_network:
return layers, ancillary_layers
# Ancillary layers are those with depth < 0.
depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
depths.sort(reverse=True) # Order topologically from inputs to outputs.
for depth in depths:
for node in model._nodes_by_depth[depth]:
ancillary_layers.append(layer_map[node.outbound_layer])
return [l for l in layers if l not in ancillary_layers], ancillary_layers
def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
"""Clone a `Sequential` model instance.
@@ -238,45 +298,73 @@ def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
if not callable(layer_fn):
raise ValueError('Expected `layer_fn` argument to be a callable.')
layers = [] # Layers needed to compute the model's outputs.
layer_map = {}
# Use model._layers to ensure that all layers are cloned. The model's layers
# property will exclude the initial InputLayer (if it exists) in the model,
# resulting in a different Sequential model structure.
for layer in model._layers:
if isinstance(layer, InputLayer) and input_tensors is not None:
# If input tensors are provided, the original model's InputLayer is
# overwritten with a different InputLayer.
continue
cloned_layer = (
_clone_layer(layer)
if isinstance(layer, InputLayer) else layer_fn(layer))
layers.append(cloned_layer)
layer_map[layer] = cloned_layer
layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)
if input_tensors is None:
layers = []
for layer in model._layers:
if isinstance(layer, InputLayer):
layers.append(_clone_layer(layer))
else:
layers.append(layer_fn(layer))
return Sequential(layers=layers, name=model.name)
cloned_model = Sequential(layers=layers, name=model.name)
elif len(generic_utils.to_list(input_tensors)) != 1:
raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor '
'as part of `input_tensors`.')
else:
# If input tensors are provided, the original model's InputLayer is
# overwritten with a different InputLayer.
layers = [
layer_fn(layer)
for layer in model._layers
if not isinstance(layer, InputLayer)
]
if len(generic_utils.to_list(input_tensors)) != 1:
raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor '
'as part of `input_tensors`.')
# Overwrite the original model's input layer.
if isinstance(input_tensors, tuple):
input_tensors = list(input_tensors)
x = generic_utils.to_list(input_tensors)[0]
if K.is_keras_tensor(x):
origin_layer = x._keras_history.layer
if isinstance(origin_layer, InputLayer):
return Sequential(layers=[origin_layer] + layers, name=model.name)
cloned_model = Sequential(
layers=[origin_layer] + layers, name=model.name)
else:
raise ValueError('Cannot clone a `Sequential` model on top '
'of a tensor that comes from a Keras layer '
'other than an `InputLayer`. '
'Use the functional API instead.')
input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
input_layer = input_tensor._keras_history.layer
return Sequential(layers=[input_layer] + layers, name=model.name)
else:
input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
input_layer = input_tensor._keras_history.layer
cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)
if not ancillary_layers:
return cloned_model
tensor_map = {} # Maps tensors from `model` to those in `cloned_model`.
for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
nodes = model._nodes_by_depth[depth]
# This should be safe in a Sequential model. In an arbitrary network, you
# need to sort using the outbound layer of the node as a key.
for cloned_node, node in zip(cloned_nodes, nodes):
if isinstance(cloned_node.output_tensors, list):
for j, output_tensor in enumerate(cloned_node.output_tensors):
tensor_map[node.output_tensors[j]] = output_tensor
else:
tensor_map[node.output_tensors] = cloned_node.output_tensors
# Ancillary nodes have negative depth.
new_nodes = _make_new_nodes(
{
depth: nodes
for depth, nodes in model._nodes_by_depth.items()
if depth < 0
}, layer_fn, layer_map, tensor_map)
_insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
new_nodes)
return cloned_model
@keras_export('keras.models.clone_model')
......
@@ -815,11 +815,9 @@ class TestWholeModelSaving(test.TestCase):
def _make_model():
inputs = keras.Input(shape=(4,))
x = keras.layers.Dense(8, activation='relu')(inputs)
y = keras.layers.Dense(3, activation='softmax')(x)
custom_loss = keras.layers.Lambda(lambda x: keras.backend.sum(x * x))(x)
# Connect the loss to the network.
outputs = keras.layers.Lambda(lambda x: x[0])((y, custom_loss))
outputs = keras.layers.Dense(3, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
custom_loss = keras.layers.Lambda(lambda x: keras.backend.sum(x * x))(x)
model.add_loss(custom_loss)
model.add_metric(custom_loss, aggregation='mean', name='custom_loss')
return model
......