Commit ee418c8e authored by Katherine Wu, committed by TensorFlower Gardener

Add attribute to Keras model which generates an exportable tf.function.

SavedModel save now looks for this attribute when searching for a function to export.

PiperOrigin-RevId: 224861089
Parent 841f5d9f
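In practice, this commit lets `tf.saved_model.save` export a Keras model whose input shapes are known without the caller supplying a signature, via the new `_default_save_signature` attribute. A minimal sketch, assuming the public TF 2.x-style API (the save path, layer sizes, and data are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(3, input_shape=(5,), name='dense')])
model.compile(optimizer='sgd', loss='mse')
# Fitting (or predicting) sets the model's inputs; the default save
# signature is traced from those input shapes and dtypes.
model.fit(np.random.random((8, 5)), np.random.random((8, 3)), epochs=1)

# With this change, save() falls back to model._default_save_signature
# when no @tf.function-decorated method or explicit signature is given.
tf.saved_model.save(model, '/tmp/exported_keras_model')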
@@ -342,6 +342,10 @@ class PolymorphicFunction(object):
"""The python function wrapped in this tf.function."""
return self._python_function
@property
def input_signature(self):
return self._input_signature
def get_initialization_function(self, *args, **kwargs):
"""Returns a `Function` object which initializes this function's variables.
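The new read-only `input_signature` property surfaces the signature a `tf.function` was constructed with (or `None` if none was given); `trace_model_call` below relies on it. A small sketch using the public decorator:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 5], dtype=tf.float32)])
def double(x):
  return x * 2.

# Prints the list of TensorSpecs that was passed to the decorator.
print(double.input_signature)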
@@ -848,6 +848,7 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
"//tensorflow/python/saved_model:save_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@@ -1539,8 +1539,7 @@ class Model(Network):
outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = [
'output_%d' % (i + 1) for i in range(len(self.outputs))]
self.output_names = training_utils.generic_output_names(outputs)
self.built = True
def fit(self,
@@ -2580,6 +2579,10 @@ class Model(Network):
batch_size = 32
return batch_size
@property
def _default_save_signature(self):
return training_utils.trace_model_call(self)
class DistributedCallbackModel(Model):
"""Model that is used for callbacks with DistributionStrategy."""
@@ -27,9 +27,11 @@ import six
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
@@ -1191,3 +1193,61 @@ def get_static_batch_size(layer):
if batch_input_shape is not None:
return tensor_shape.as_dimension(batch_input_shape[0]).value
return None
def generic_output_names(outputs_list):
return ['output_%d' % (i + 1) for i in range(len(outputs_list))]
def trace_model_call(model, input_signature=None):
"""Trace the model call to create a tf.function for exporting a Keras model.
Args:
model: A Keras model.
input_signature: optional, a list of tf.TensorSpec objects specifying the
inputs to the model.
Returns:
A tf.function wrapping the model's call function with input signatures set.
Raises:
ValueError: if input signature cannot be inferred from the model.
"""
if input_signature is None:
if isinstance(model.call, def_function.PolymorphicFunction):
input_signature = model.call.input_signature
if input_signature is None:
try:
inputs = model.inputs
input_names = model.input_names
except AttributeError:
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
input_specs = []
for input_tensor, input_name in zip(inputs, input_names):
input_specs.append(tensor_spec.TensorSpec(
shape=input_tensor.shape, dtype=input_tensor.dtype,
name=input_name))
# The input signature of the call function is a list with one element, since
# all tensor inputs must be passed in as the first argument.
input_signature = [input_specs] if len(input_specs) > 1 else input_specs
@def_function.function(input_signature=input_signature)
def _wrapped_model(*args):
"""A concrete tf.function that wraps the model's call function."""
# When given a single input, Keras models will call the model on the tensor
# rather than a list consisting of the single tensor.
inputs = args[0] if len(input_signature) == 1 else list(args)
outputs_list = nest.flatten(model(inputs=inputs))
try:
output_names = model.output_names
except AttributeError:
output_names = generic_output_names(outputs_list)
return {name: output for name, output in zip(output_names, outputs_list)}
return _wrapped_model
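For orientation, a hedged sketch of calling the new helper directly, using the internal module path from this diff (the tests added further down exercise the same flow); the explicit TensorSpec is needed because this model's input shapes have not been set yet:

import tensorflow as tf
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import training_utils

model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
fn = training_utils.trace_model_call(
    model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=tf.float32)])
# The traced function returns a dict keyed by output name,
# e.g. {'output_1': <float32 Tensor of shape (8, 3)>}.
outputs = fn(tf.ones((8, 5)))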
@@ -18,13 +18,25 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import save_test
class ModelInputsTest(test.TestCase):
@@ -85,5 +97,150 @@ class ModelInputsTest(test.TestCase):
self.assertTrue(tf_utils.is_symbolic_tensor(vals['b']))
class TraceModelCallTest(keras_parameterized.TestCase):
def _assert_all_close(self, expected, actual):
if not context.executing_eagerly():
with self.cached_session() as sess:
K._initialize_variables(sess)
self.assertAllClose(expected, actual)
else:
self.assertAllClose(expected, actual)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_trace_model_outputs(self):
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
model = testing_utils.get_small_mlp(10, 3, input_dim)
inputs = array_ops.ones((8, 5))
if input_dim is None:
with self.assertRaisesRegexp(ValueError,
'input shapes have not been set'):
training_utils.trace_model_call(model)
model._set_inputs(inputs)
fn = training_utils.trace_model_call(model)
signature_outputs = fn(inputs)
expected_outputs = {model.output_names[0]: model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_trace_model_outputs_after_fitting(self):
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
model = testing_utils.get_small_mlp(10, 3, input_dim)
model.compile(optimizer='sgd', loss='mse')
model.fit(x=np.random.random((8, 5)),
y=np.random.random((8, 3)), epochs=2)
inputs = array_ops.ones((8, 5))
fn = training_utils.trace_model_call(model)
signature_outputs = fn(inputs)
expected_outputs = {model.output_names[0]: model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_with_all_model_types(exclude_models='sequential')
@keras_parameterized.run_all_keras_modes
def test_trace_multi_io_model_outputs(self):
input_dim = 5
num_classes = 3
num_classes_b = 4
input_a = keras.layers.Input(shape=(input_dim,), name='input_a')
input_b = keras.layers.Input(shape=(input_dim,), name='input_b')
dense = keras.layers.Dense(num_classes, name='dense')
dense2 = keras.layers.Dense(num_classes_b, name='dense2')
dropout = keras.layers.Dropout(0.5, name='dropout')
branch_a = [input_a, dense]
branch_b = [input_b, dense, dense2, dropout]
model = testing_utils.get_multi_io_model(branch_a, branch_b)
input_a_np = np.random.random((10, input_dim)).astype(np.float32)
input_b_np = np.random.random((10, input_dim)).astype(np.float32)
if testing_utils.get_model_type() == 'subclass':
with self.assertRaisesRegexp(ValueError,
'input shapes have not been set'):
training_utils.trace_model_call(model)
model.compile(optimizer='sgd', loss='mse')
model.fit(x=[np.random.random((8, input_dim)).astype(np.float32),
np.random.random((8, input_dim)).astype(np.float32)],
y=[np.random.random((8, num_classes)).astype(np.float32),
np.random.random((8, num_classes_b)).astype(np.float32)],
epochs=2)
fn = training_utils.trace_model_call(model)
signature_outputs = fn([input_a_np, input_b_np])
outputs = model([input_a_np, input_b_np])
expected_outputs = {model.output_names[0]: outputs[0],
model.output_names[1]: outputs[1]}
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_all_keras_modes
def test_specify_input_signature(self):
model = testing_utils.get_small_sequential_mlp(10, 3, None)
inputs = array_ops.ones((8, 5))
with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'):
training_utils.trace_model_call(model)
fn = training_utils.trace_model_call(
model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
signature_outputs = fn(inputs)
expected_outputs = {model.output_names[0]: model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_all_keras_modes
def test_subclassed_model_with_input_signature(self):
class Model(keras.Model):
def __init__(self):
super(Model, self).__init__()
self.dense = keras.layers.Dense(3, name='dense')
@def_function.function(
input_signature=[[tensor_spec.TensorSpec([None, 5], dtypes.float32),
tensor_spec.TensorSpec([None], dtypes.float32)]],)
def call(self, inputs, *args):
x, y = inputs
return self.dense(x) + y
model = Model()
fn = training_utils.trace_model_call(model)
x = array_ops.ones((8, 5), dtype=dtypes.float32)
y = array_ops.ones((3,), dtype=dtypes.float32)
expected_outputs = {'output_1': model([x, y])}
signature_outputs = fn([x, y])
self._assert_all_close(expected_outputs, signature_outputs)
class ModelSaveTest(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_model_save(self):
input_dim = 5
model = testing_utils.get_small_mlp(10, 3, input_dim)
inputs = array_ops.ones((8, 5))
if testing_utils.get_model_type() == 'subclass':
model._set_inputs(inputs)
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save_lib.save(model, save_dir)
self.assertAllClose(
{model.output_names[0]: model.predict_on_batch(inputs)},
save_test._import_and_infer(save_dir,
{model.input_names[0]: np.ones((8, 5))}))
if __name__ == '__main__':
test.main()
@@ -31,7 +31,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -50,28 +49,7 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
def _check_for_functional_keras_model(root):
"""Makes an export signature for `root` if it's a functional Keras Model."""
# If nothing is decorated yet but this is a functional Keras Model (duck
# typed), we'll try to make a signature ourselves.
try:
inputs = root.inputs
input_names = root.input_names
except AttributeError:
return None
input_signature = []
for input_tensor, input_name in zip(inputs, input_names):
input_signature.append(tensor_spec.TensorSpec(
shape=input_tensor.shape, dtype=input_tensor.dtype,
name=input_name))
@def_function.function(input_signature=input_signature)
def _wrapped_model(*args):
outputs_list = nest.flatten(root(inputs=list(args)))
return {name: output for name, output
in zip(root.output_names, outputs_list)}
return _wrapped_model
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
def _find_function_to_export(root):
@@ -93,7 +71,7 @@ def _find_function_to_export(root):
exported_function = attribute_value
previous_attribute_name = attribute_name
if exported_function is None:
exported_function = _check_for_functional_keras_model(root)
exported_function = getattr(root, DEFAULT_SIGNATURE_ATTR, None)
if exported_function is None:
raise ValueError(
("Exporting an object with no tf.saved_model.save(..., signatures=...) "
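Because the fallback is a plain `getattr(root, DEFAULT_SIGNATURE_ATTR, None)`, any checkpointable object can opt in by exposing a fully-specified tf.function under that name, not just Keras models. A hedged sketch written against the public API, mirroring `test_find_default_save_function` below:

import tensorflow as tf

class ObjWithDefaultSignature(tf.train.Checkpoint):

  # The attribute name `_default_save_signature` is what save() looks up;
  # the full input_signature makes the function traceable without a call.
  @tf.function(input_signature=[
      tf.TensorSpec(shape=None, dtype=tf.float32)])
  def _default_save_signature(self, x):
    return x + x + 1.

# save() finds the attribute and exports it as the serving signature.
tf.saved_model.save(ObjWithDefaultSignature(), '/tmp/obj_with_signature')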
@@ -21,8 +21,6 @@ from __future__ import print_function
import os
import sys
import numpy
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
@@ -32,12 +30,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import merge
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@@ -50,10 +44,9 @@ from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
class _ModelWithOptimizer(training.Model):
class _ModelWithOptimizer(util.Checkpoint):
def __init__(self):
super(_ModelWithOptimizer, self).__init__()
self.dense = core.Dense(1)
self.optimizer = adam.AdamOptimizer(0.01)
@@ -63,7 +56,7 @@ class _ModelWithOptimizer(training.Model):
def call(self, x, y):
with backprop.GradientTape() as tape:
loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
trainable_variables = self.trainable_variables
trainable_variables = self.dense.trainable_variables
gradients = tape.gradient(loss, trainable_variables)
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
return {"loss": loss}
@@ -179,10 +172,10 @@ class SaveTest(test.TestCase):
x = constant_op.constant([[3., 4.]])
y = constant_op.constant([2.])
model = _ModelWithOptimizer()
first_loss = model(x, y)
first_loss = model.call(x, y)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir, model.call)
second_loss = model(x, y)
second_loss = model.call(x, y)
self.assertNotEqual(first_loss, second_loss)
self.assertAllClose(
second_loss,
@@ -197,7 +190,7 @@ class SaveTest(test.TestCase):
model = _ModelWithOptimizer()
x = constant_op.constant([[3., 4.]])
y = constant_op.constant([2.])
model(x, y)
model.call(x, y)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir)
self.assertIn("loss",
@@ -217,25 +210,40 @@ class SaveTest(test.TestCase):
model = _ModelWithOptimizer()
x = constant_op.constant([[3., 4.]])
y = constant_op.constant([2.])
model(x, y)
model.call(x, y)
model.second_function = def_function.function(lambda: 1.)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
with self.assertRaisesRegexp(ValueError, "call.*second_function"):
save.save(model, save_dir)
def test_subclassed_no_signature(self):
def test_no_signature(self):
class Subclassed(training.Model):
class Model(util.Checkpoint):
def call(self, inputs):
return inputs * 2.
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
model = Subclassed()
model = Model()
with self.assertRaisesRegexp(
ValueError, "no @tf.function-decorated methods"):
save.save(model, save_dir)
def test_find_default_save_function(self):
class ObjWithDefaultSignature(util.Checkpoint):
@def_function.function(input_signature=[tensor_spec.TensorSpec(
shape=None, dtype=dtypes.float32)])
def _default_save_signature(self, x):
return x + x + 1
obj = ObjWithDefaultSignature()
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(obj, save_dir)
self.assertAllClose(
{"output_0": 7.}, _import_and_infer(save_dir, {"x": 3.}))
def test_docstring(self):
class Adder(util.Checkpoint):
@@ -276,46 +284,6 @@ class SaveTest(test.TestCase):
self.assertNotIn("T", complex_node.attr)
self.assertNotIn("Tout", complex_node.attr)
def test_export_functional_keras_model(self):
x = input_layer.Input((4,), name="x")
y = core.Dense(4, name="out")(x)
model = training.Model(x, y)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir)
self.assertAllClose(
{"out": model(array_ops.ones([1, 4]))},
_import_and_infer(save_dir, {"x": [[1., 1., 1., 1.]]}))
@test_util.run_v1_only("b/120545219")
def test_export_functional_keras_model_after_fit(self):
x = input_layer.Input((1,))
y = core.Dense(1, name="y")(x)
model = training.Model(x, y)
model.compile(optimizer="sgd", loss="mse")
model.fit(x=numpy.array([[1.]]),
y=numpy.array([2.]), epochs=2)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir)
self.assertAllClose(
{"y": model(constant_op.constant([[1.], [2.]]))},
_import_and_infer(save_dir, {"input_1": [[1.], [2.]]}))
def test_export_multi_input_functional_keras_model(self):
x1 = input_layer.Input((2,), name="x1")
x2 = input_layer.Input((2,), name="x2")
y1 = core.Dense(4)(merge.Add()([x1, x2]))
y2 = core.Dense(4)(merge.Multiply()([x1, x2]))
model = training.Model([x1, x2], [y1, y2])
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir)
outputs = model([array_ops.ones([1, 2]), 2. * array_ops.ones([1, 2])])
self.assertAllClose(
{"dense": outputs[0], "dense_1": outputs[1]},
_import_and_infer(
save_dir,
{"x1": [[1., 1.]],
"x2": [[2., 2.]]}))
class AssetTests(test.TestCase):
@@ -376,7 +344,7 @@ class MemoryTests(test.TestCase):
def test_no_reference_cycles(self):
x = constant_op.constant([[3., 4.]])
y = constant_op.constant([2.])
self._model(x, y)
self._model.call(x, y)
if sys.version_info[0] < 3:
# TODO(allenl): debug reference cycles in Python 2.x
self.skipTest("This test only works in Python 3+. Reference cycles are "