提交 e77b0dc6 编写于 作者: K Katherine Wu 提交者: TensorFlower Gardener

Add property that allows layers to specify that the input_spec can also be...

Add property that allows layers to specify that the input_spec can also be used as the layer call function's input_signature.

By default,all Keras exported layers will have this property return True, since Keras more rigidly defines the input_spec shape (compared to user-defined models, which may only have `ndims` set).

PiperOrigin-RevId: 339909813
Change-Id: I1b486747aa1e413e6f24f6809fb88846ad4712ab
上级 40c0e9dc
......@@ -83,6 +83,7 @@ from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
......@@ -349,13 +350,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Indicates whether `build` needs to be called upon layer call, to create
# the layer's weights.
self.built = False
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
# SavedModel-related attributes.
# Record the build input shape for loading purposes.
# TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
# submitted.
self._build_input_shape = None
self._saved_model_inputs_spec = None
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
# `Layer.compute_mask` will be called at the end of `Layer.__call__` if
# `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
......@@ -3086,6 +3089,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return (self._trackable_saved_model_saver
.list_functions_for_serialization(serialization_cache))
@property
def _use_input_spec_as_call_signature(self):
# Whether input spec can be used as the call signature when tracing the
# Layer for SavedModel. By default, this is set to `True` for layers
# exported from the Keras library, because the layers more rigidly define
# the `input_specs` property (many custom layers only set the `ndims`)
return get_canonical_name_for_symbol(type(self)) is not None
def __getstate__(self):
# Override to support `copy.deepcopy` and pickling.
# Thread-local objects cannot be copied in Python 3, so pop these.
......
......@@ -443,6 +443,16 @@ class RNN(Layer):
raise ValueError('RNNs with stateful=True not yet supported with '
'tf.distribute.Strategy.')
@property
def _use_input_spec_as_call_signature(self):
if self.unroll:
# When the RNN layer is unrolled, the time step shape cannot be unknown.
# The input spec does not define the time step (because this layer can be
# called with any time step value, as long as it is not None), so it
# cannot be used as the call function signature when saving to SavedModel.
return False
return super(RNN, self)._use_input_spec_as_call_signature
@property
def states(self):
if self._states is None:
......
......@@ -377,26 +377,26 @@ class LayerCallCollection(object):
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
return layer.call.input_signature
elif isinstance(layer, training_lib.Model):
return saving_utils.model_input_signature(layer)
elif (layer.input_spec is not None and
layer._use_input_spec_as_call_signature): # pylint: disable=protected-access
def to_tensor_spec_or_none(x):
spec = input_spec.to_tensor_spec(x, layer._compute_dtype) # pylint: disable=protected-access
# If the shape is too general (e.g. multiple dimensions are allowed),
# return None so that separate functions can be generated for each
# inferred input signature.
# TODO(b/134962016): currently partial signatures are not supported.
if spec.shape == tensor_shape.TensorShape(None):
return None
return spec
input_signature = [nest.map_structure(
to_tensor_spec_or_none, layer.input_spec)]
return input_signature
else:
if isinstance(layer, training_lib.Model):
return saving_utils.model_input_signature(layer)
elif layer.input_spec is not None:
def to_tensor_spec_or_none(x):
spec = input_spec.to_tensor_spec(x, layer._compute_dtype) # pylint: disable=protected-access
# If the shape is too general (e.g. multiple dimensions are allowed),
# return None so that separate functions can be generated for each
# inferred input signature.
# TODO(b/134962016): currently partial signatures are not supported.
if spec.shape == tensor_shape.TensorShape(None):
return None
return spec
input_signature = [nest.map_structure(
to_tensor_spec_or_none, layer.input_spec)]
return input_signature
else:
return None
return None
def add_trace(self, *args, **kwargs):
"""Traces all functions with the same args and kwargs.
......
......@@ -83,6 +83,10 @@ class LayerWithLearningPhase(keras.engine.base_layer.Layer):
def compute_output_shape(self, input_shape):
return input_shape
@property
def _use_input_spec_as_call_signature(self):
return True
class LayerWithLoss(keras.layers.Layer):
......@@ -326,6 +330,10 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
@property
def _use_input_spec_as_call_signature(self):
return True
layer = LayerWithNestedSpec()
saved_model_dir = self._save_model_dir()
tf_save.save(layer, saved_model_dir)
......@@ -737,8 +745,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
predictions)
@parameterized.named_parameters([
# TODO(b/148491963): Unrolling does not work with SavedModel
# ('with_unrolling', True),
('with_unrolling', True),
('no_unrolling', False)
])
def testSaveStatefulRNN(self, unroll):
......@@ -882,6 +889,10 @@ class TestSavedModelFormat(test.TestCase):
def get_config(self):
return {}
@property
def _use_input_spec_as_call_signature(self):
return True
root = keras.models.Sequential()
root.add(keras.layers.Input(shape=(3,)))
root.attached_layer = DoNotTrace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册