提交 c9b5ba76 编写于 作者: K Katherine Wu 提交者: TensorFlower Gardener

Add `Model.save_spec` property to get the model's call argument TensorSpecs.

This change also enables models with multiple input arguments to be saved, for example:

```
class Subclass(keras.Model):

  def call(self, a, b):
    ...
```

PiperOrigin-RevId: 381148805
上级 f26e2d64
......@@ -316,6 +316,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -326,6 +326,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -317,6 +317,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -317,6 +317,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -316,6 +316,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -326,6 +326,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -316,6 +316,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -326,6 +326,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -317,6 +317,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -317,6 +317,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -316,6 +316,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -326,6 +326,10 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "save_spec"
argspec: "args=[\'self\', \'dynamic_batch\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
......
......@@ -341,6 +341,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
# submitted.
self._build_input_shape = None
self._saved_model_inputs_spec = None
self._saved_model_arg_spec = None
# `Layer.compute_mask` will be called at the end of `Layer.__call__` if
# `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
......@@ -963,7 +964,6 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
# - input_spec compatibility is only checked against `inputs`
# - mixed precision casting (autocast) is only applied to `inputs`,
# not to any other argument.
# - setting the SavedModel saving spec.
inputs, args, kwargs = self._split_out_first_arg(args, kwargs)
input_list = tf.nest.flatten(inputs)
......@@ -1041,7 +1041,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
if self._supports_masking:
self._set_mask_metadata(inputs, outputs, input_masks, not eager)
if self._saved_model_inputs_spec is None:
self._set_save_spec(inputs)
self._set_save_spec(inputs, args, kwargs)
return outputs
......@@ -3015,20 +3015,54 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
# SavedModel properties. Please see keras/saving/saved_model for details.
@tf.__internal__.tracking.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
def _set_save_spec(self, inputs, args=None, kwargs=None):
"""Defines the save spec so that serialization is able to trace layer call.
The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
saved into a tuple of `([inputs] + args, kwargs)`.
Args:
inputs: possibly nested inputs passed into the call function.
args: a list of positional arguments passed into call.
kwargs: a dictionary of keyword arguments passed into call.
"""
if self._saved_model_inputs_spec is not None:
return # Already set.
args = args or []
kwargs = kwargs or {}
inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
# Filter out non-tensor arguments from args and kwargs.
args_spec = []
for arg in args:
flat_arg = tf.nest.flatten(arg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_arg]
if any(s is None for s in flat_specs):
break # Stop recording positional args once a non-tensor has been found
args_spec.append(tf.nest.pack_sequence_as(arg, flat_specs))
kwargs_spec = {}
for key, kwarg in kwargs.items():
if key == 'training':
continue
flat_kwarg = tf.nest.flatten(kwarg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
if any(s is None for s in flat_specs):
continue
kwargs[key] = args_spec.append(
tf.nest.pack_sequence_as(kwarg, flat_specs))
self._saved_model_inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec,
inputs)
self._saved_model_inputs_spec = inputs_spec
self._saved_model_arg_spec = ([inputs_spec] + args_spec, kwargs_spec)
def _get_save_spec(self, dynamic_batch=True):
def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
if self._saved_model_inputs_spec is None:
return None
return tf.nest.map_structure(
spec = tf.nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_inputs_spec)
self._saved_model_arg_spec)
return spec[0][0] if inputs_only else spec
@property
def _trackable_saved_model_saver(self):
......
......@@ -799,6 +799,7 @@ class Layer(base_layer.Layer):
# TODO(b/120997007): This should be done in Eager as well, but
# causes garbage collection issues because of the placeholders
# created on the default Keras graph.
self._set_save_spec(inputs, args, kwargs)
self._set_inputs(inputs, outputs)
else:
# Eager execution on data tensors.
......
......@@ -863,12 +863,12 @@ class Functional(training_lib.Model):
def _trackable_saved_model_saver(self):
return network_serialization.NetworkSavedModelSaver(self)
def _get_save_spec(self, dynamic_batch=True):
def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
if getattr(self, '_has_explicit_input_shape', True):
# Functional models and Sequential models that have an explicit input
# shape should use the batch size set by the input layer.
dynamic_batch = False
return super(Functional, self)._get_save_spec(dynamic_batch)
return super(Functional, self)._get_save_spec(dynamic_batch, inputs_only)
def _make_node_key(layer_name, node_index):
......
......@@ -286,6 +286,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# Fault-tolerance handler. Set in `ModelCheckpoint`.
self._training_state = None
self._saved_model_inputs_spec = None
self._saved_model_arg_spec = None
self._trackable_saver = saver_with_op_caching(self)
self._steps_per_execution = None
......@@ -2569,28 +2570,81 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
raise ValueError('Provide either a layer name or layer index.')
@tf.__internal__.tracking.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
def _set_save_spec(self, inputs, args=None, kwargs=None):
"""Defines the save spec so that serialization is able to trace model call.
The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
saved into a tuple of `([inputs] + args, kwargs)`. The input `TensorSpec`
names are updated to match the built `input_names`.
The specs can be retrieved with the `save_spec` property.
Args:
inputs: possibly nested inputs passed into the call function.
args: a list of positional arguments passed into call.
kwargs: a dictionary of keyword arguments passed into call.
"""
if self._saved_model_inputs_spec is not None:
return # Already set.
args = args or []
kwargs = kwargs or {}
input_names = self.input_names
if not input_names:
input_names = compile_utils.create_pseudo_input_names(inputs)
flat_inputs = tf.nest.flatten(inputs)
specs = []
inputs_spec = []
for name, tensor in zip(input_names, flat_inputs):
specs.append(
inputs_spec.append(
tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
specs = tf.nest.pack_sequence_as(inputs, specs)
self._saved_model_inputs_spec = specs
inputs_spec = tf.nest.pack_sequence_as(inputs, inputs_spec)
super(Model, self)._set_save_spec(inputs_spec, args, kwargs)
# Store the input shapes
if (self.__class__.__name__ == 'Sequential' and
self._build_input_shape is None):
self._build_input_shape = tf.nest.map_structure(
lambda x: None if x is None else x.shape, specs)
lambda x: None if x is None else x.shape, inputs_spec)
def save_spec(self, dynamic_batch=True):
"""Returns the `tf.TensorSpec` of call inputs as a tuple `(args, kwargs)`.
This value is automatically defined after calling the model for the first
time. Afterwards, you can use it when exporting the model for serving:
```python
model = tf.keras.Model(...)
@tf.function
def serve(*args, **kwargs):
outputs = model(*args, **kwargs)
# Apply postprocessing steps, or add additional outputs.
...
return outputs
# arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this example, is
# an empty dict since functional models do not use keyword arguments.
arg_specs, kwarg_specs = model.save_spec()
model.save(path, signatures={
'serving_default': serve.get_concrete_function(*arg_specs, **kwarg_specs)
})
```
Args:
dynamic_batch: Whether to set the batch sizes of all the returned
`tf.TensorSpec` to `None`. (Note that when defining functional or
Sequential models with `tf.keras.Input([...], batch_size=X)`, the
batch size will always be preserved). Defaults to `True`.
Returns:
If the model inputs are defined, returns a tuple `(args, kwargs)`. All
elements in `args` and `kwargs` are `tf.TensorSpec`.
If the model inputs are not defined, returns `None`.
The model inputs are automatically set when calling the model,
`model.fit`, `model.evaluate` or `model.predict`.
"""
return self._get_save_spec(dynamic_batch, inputs_only=False)
def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created.
......
......@@ -128,6 +128,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
# When there are no Keras objects, return the results from the core loader
return tf.saved_model.load(path, options=options)
metadata = _update_to_current_version(metadata)
# Recreate layers and metrics using the info stored in the metadata.
keras_loader = KerasObjectLoader(metadata, object_graph_def)
keras_loader.load_layers(compile=compile)
......@@ -173,6 +174,21 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
return model
def _update_to_current_version(metadata):
"""Applies version updates to the metadata proto for backwards compat."""
for node in metadata.nodes:
if node.version.producer == 1 and node.identifier in [
constants.MODEL_IDENTIFIER, constants.SEQUENTIAL_IDENTIFIER,
constants.NETWORK_IDENTIFIER]:
node_metadata = json_utils.decode(node.metadata)
save_spec = node_metadata.get('save_spec')
if save_spec is not None:
node_metadata['full_save_spec'] = ([save_spec], {})
node.metadata = json_utils.Encoder().encode(node_metadata)
return metadata
def _read_legacy_metadata(object_graph_def, metadata):
"""Builds a KerasMetadata proto from the SavedModel ObjectGraphDef."""
# Older SavedModels store the metadata directly in the proto instead of the
......@@ -529,9 +545,11 @@ class KerasObjectLoader(object):
# Restore model save spec for subclassed models. (layers do not store a
# SaveSpec)
if isinstance(obj, training_lib.Model):
save_spec = metadata.get('save_spec')
if save_spec is not None:
obj._set_save_spec(save_spec)
full_save_spec = metadata.get('full_save_spec')
if full_save_spec is not None:
args_spec, kwargs_spec = full_save_spec
inputs_spec = args_spec.pop(0)
obj._set_save_spec(inputs_spec, args_spec, kwargs_spec)
# pylint: enable=protected-access
build_input_shape = metadata.get('build_input_shape')
......@@ -816,10 +834,22 @@ def _finalize_saved_model_layers(layers):
if not call_fn.concrete_functions:
continue
if call_fn.input_signature is None:
inputs = infer_inputs_from_restored_call_function(call_fn)
args, kwargs = infer_inputs_from_restored_call_function(call_fn)
args = list(args)
inputs = args.pop(0)
else:
inputs = call_fn.input_signature[0]
layer._set_inputs(inputs) # pylint: disable=protected-access
args = call_fn.input_signature
args = list(args)
inputs = args.pop(0)
kwargs = None
layer._set_save_spec(inputs, args, kwargs) # pylint: disable=protected-access
# V1 models require calling _set_inputs to set the `.inputs` attr.
# Skip this step when there are multiple tensor inputs (this behavior
# is not well supported in V1 models).
if not any(isinstance(x, tf.TensorSpec)
for x in tf.nest.flatten([args, kwargs])):
layer._set_inputs(inputs)
# 3. Add losses that aren't generated by the layer.call function.
_restore_layer_unconditional_losses(layer)
......@@ -1124,9 +1154,13 @@ def infer_inputs_from_restored_call_function(fn):
one concrete function and that the inputs are in the first argument.
Returns:
TensorSpec of call function inputs.
TensorSpec of call function inputs in the form of (args, kwargs)
"""
def common_spec(x, y):
if not isinstance(x, tf.TensorSpec):
# Doesn't particularly matter what is returned in this case because the
# result will be filtered out in _set_input_shape.
return x
common_shape = get_common_shape(x.shape, y.shape)
if isinstance(x, tf.SparseTensorSpec):
return tf.SparseTensorSpec(common_shape, x.dtype)
......@@ -1134,9 +1168,9 @@ def infer_inputs_from_restored_call_function(fn):
return tf.RaggedTensorSpec(common_shape, x.dtype)
return tf.TensorSpec(common_shape, x.dtype, x.name)
spec = fn.concrete_functions[0].structured_input_signature[0][0]
spec = fn.concrete_functions[0].structured_input_signature
for concrete in fn.concrete_functions[1:]:
spec2 = concrete.structured_input_signature[0][0]
spec2 = concrete.structured_input_signature
spec = tf.nest.map_structure(common_spec, spec, spec2)
return spec
......
......@@ -32,7 +32,10 @@ class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
# Network stateful property is dependent on the child layers.
metadata.pop('stateful')
metadata['is_graph_network'] = self.obj._is_graph_network # pylint: disable=protected-access
metadata['save_spec'] = self.obj._get_save_spec(dynamic_batch=False) # pylint: disable=protected-access
spec = self.obj.save_spec(dynamic_batch=False)
metadata['full_save_spec'] = spec
# save_spec is saved for forward compatibility on older TF versions.
metadata['save_spec'] = None if spec is None else spec[0][0]
metadata.update(
saving_utils.model_metadata(
......
......@@ -200,7 +200,7 @@ class ReviveTestBase(keras_parameterized.TestCase):
self.evaluate(revived.weights))
input_arr = tf.constant(
np.random.random((2, 2, 3)).astype(np.float32))
if isinstance(revived._saved_model_inputs_spec,
if isinstance(revived.save_spec()[0][0],
tf.SparseTensorSpec):
input_arr = tf.sparse.from_dense(input_arr)
......
......@@ -118,7 +118,7 @@ def generate_keras_metadata(saved_nodes, node_paths):
node_id=node_id,
node_path=node_path,
version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]),
producer=2, min_consumer=1, bad_consumers=[]),
identifier=node._object_identifier, # pylint: disable=protected-access
metadata=node._tracking_metadata) # pylint: disable=protected-access
......
......@@ -193,9 +193,7 @@ def wrap_layer_functions(layer, serialization_cache):
with base_layer_utils.call_context().enter(
layer, inputs=None, build_graph=True, training=None, saving=True):
for fn in fns.values():
if fn is not None and fn.input_signature is not None:
if isinstance(fn, LayerCall):
fn = fn.wrapped_call
if fn is not None and not isinstance(fn, LayerCall):
fn.get_concrete_function()
# Restore overwritten functions and losses
......@@ -208,7 +206,6 @@ def wrap_layer_functions(layer, serialization_cache):
def default_save_signature(layer):
original_losses = _reset_layer_losses(layer)
fn = saving_utils.trace_model_call(layer)
fn.get_concrete_function()
_restore_layer_losses(original_losses)
return fn
......@@ -395,38 +392,31 @@ class LayerCallCollection(object):
self._training_arg_index = utils.get_training_arg_index(
self.layer_call_method)
# If the layer call function has kwargs, then the traced function cannot
# have an input signature.
arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
self._has_kwargs = bool(self._expects_training_arg or
arg_spec.defaults or
arg_spec.kwonlyargs or
arg_spec.varkw)
self._input_signature = self._generate_input_signature(layer)
self._layer_inputs = self._get_layer_inputs(layer)
self._functions = weakref.WeakValueDictionary()
# Get the input argument name from the args.
arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
args = arg_spec.args
if tf_inspect.ismethod(self.layer_call_method):
args = args[1:]
self._input_arg_name = args[0] if args else 'inputs'
def _generate_input_signature(self, layer):
def _get_layer_inputs(self, layer):
"""Inspects layer object and returns the inferred input signature.
Args:
layer: Layer object.
Returns:
List of possibly nested TensorSpecs of the layer call function inputs.
The list does not contain the `training` argument.
List of possibly nested TensorSpecs of the layer call function inputs in
the form of `(args, kwargs)`
"""
if (isinstance(layer.call, tf.__internal__.function.Function) and
layer.call.input_signature is not None):
return layer.call.input_signature
return layer.call.input_signature, {}
elif isinstance(layer, training_lib.Model):
return saving_utils.model_input_signature(layer)
return saving_utils.model_call_inputs(layer)
elif (layer.input_spec is not None and
layer._use_input_spec_as_call_signature): # pylint: disable=protected-access
......@@ -437,14 +427,14 @@ class LayerCallCollection(object):
# inferred input signature.
# TODO(b/134962016): currently partial signatures are not supported.
if spec.shape == tf.TensorShape(None):
return None
return None, None
return spec
input_signature = [tf.nest.map_structure(
to_tensor_spec_or_none, layer.input_spec)]
return input_signature
return input_signature, {}
else:
return None
return None, None
def add_trace(self, *args, **kwargs):
"""Traces all functions with the same args and kwargs.
......@@ -469,18 +459,6 @@ class LayerCallCollection(object):
else:
add_trace_to_queue(fn, args, kwargs)
@property
def fn_input_signature(self):
"""Returns input signature for the wrapped layer call function."""
if self._has_kwargs:
# Input signatures may only describe tensor arguments and kwargs are not
# supported.
return None
if None in tf.nest.flatten(self._input_signature):
# TODO(b/134962016): If input signature cannot be partially defined.
return None
return self._input_signature
def training_arg_was_passed(self, args, kwargs):
if not self.layer._expects_training_arg and self._expects_training_arg: # pylint: disable=protected-access
return (utils.get_training_arg(self._training_arg_index, args, kwargs)
......@@ -554,17 +532,17 @@ class LayerCallCollection(object):
fn = LayerCall(
self,
self._maybe_wrap_with_training_arg(call_fn, match_layer_training_arg),
name,
input_signature=self.fn_input_signature)
name)
self._functions[name] = fn.wrapped_call
return fn
def trace_with_input_signature(self):
"""Trace with the layer/models inferred input signature if possible."""
if (None not in tf.nest.flatten(self._input_signature) and self._has_kwargs):
if None not in tf.nest.flatten(self._layer_inputs):
# Manually add traces for layers that have keyword arguments and have
# a fully defined input signature.
self.add_trace(*self._input_signature)
args, kwargs = self._layer_inputs
self.add_trace(*args, **kwargs)
def _filtered_inputs(inputs):
......@@ -606,7 +584,7 @@ def layer_call_wrapper(call_collection, method, name):
class LayerCall(object):
"""Function that triggers traces of other functions in the same collection."""
def __init__(self, call_collection, call_fn, name, input_signature):
def __init__(self, call_collection, call_fn, name):
"""Initializes a LayerCall object.
Args:
......@@ -615,13 +593,10 @@ class LayerCall(object):
functions should be traced with the same arguments.
call_fn: A call function.
name: Name of the call function.
input_signature: Input signature of call_fn (can be None).
"""
self.call_collection = call_collection
self.input_signature = input_signature
self.wrapped_call = tf.function(
layer_call_wrapper(call_collection, call_fn, name),
input_signature=input_signature)
layer_call_wrapper(call_collection, call_fn, name))
self.original_layer_call = call_collection.layer_call_method
def _maybe_trace(self, args, kwargs):
......
......@@ -38,6 +38,9 @@ from keras import keras_parameterized
from keras import regularizers
from keras import testing_utils
from keras.feature_column.dense_features import DenseFeatures
from keras.protobuf import saved_metadata_pb2
from keras.protobuf import versions_pb2
from keras.saving.saved_model import json_utils
from keras.saving.saved_model import load as keras_load
from keras.saving.saved_model import save_impl as keras_save
from keras.utils import control_flow_util
......@@ -879,35 +882,32 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
self.evaluate(tf.compat.v1.variables_initializer(loaded.variables))
self.assertAllClose(model.predict(f), loaded.predict(f))
def testSaveLayerMultipleInputs(self):
def testSaveMultipleInputs(self):
class CustomLayer(keras.layers.Layer):
def call(self, *input_list):
self.add_loss(input_list[-2] * 2, inputs=True)
return sum(input_list[:-1]) # The test's last input is a non-tensor arg
# TODO(b/175902133): Models only support one input argument. Also, create a
# subclassed model because functional/sequential models still have funky
# behavior when calling with multiple non-nested arguments.
class CustomModel(keras.Model):
def build(self, _):
self.layer = CustomLayer()
def call(self, inputs):
inputs = inputs[:]
def call(self, *inputs):
inputs = list(inputs)
inputs.append(object()) # Test that the layer handles non-tensor inputs
return self.layer(*inputs)
model = CustomModel()
inp = [tf.constant(i, shape=[1, 1], dtype=tf.float32)
for i in range(1, 5)]
expected = model(inp)
expected = model(*inp)
expected_loss = model.get_losses_for(inp)
saved_model_dir = self._save_model_dir()
model.save(saved_model_dir, save_format='tf')
loaded = keras_load.load(saved_model_dir)
actual = loaded(inp)
actual = loaded(*inp)
actual_loss = loaded.get_losses_for(inp)
self.assertAllEqual(self.evaluate(expected),
self.evaluate(actual))
......@@ -1321,5 +1321,25 @@ class MetricTest(tf.test.TestCase, parameterized.TestCase):
self.evaluate([v.initializer for v in loaded.variables])
loaded.fit(x, y)
class TestUpdateMetadata(tf.test.TestCase):
def testAddFullSaveSpec(self):
save_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
node_metadata = json_utils.Encoder().encode({'save_spec': save_spec})
metadata = saved_metadata_pb2.SavedMetadata()
metadata.nodes.add(
version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]),
identifier='_tf_keras_model',
metadata=node_metadata) # pylint: disable=protected-access
new_metadata = keras_load._update_to_current_version(metadata)
node_metadata = json_utils.decode(new_metadata.nodes[0].metadata)
expected_full_spec = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
self.assertAllEqual(expected_full_spec, node_metadata.get('full_save_spec'))
if __name__ == '__main__':
tf.test.main()
......@@ -191,7 +191,9 @@ class SerializedAttributes(object):
if key in function_dict:
if (function_dict[key] is not None and # Not all functions are required
not isinstance(function_dict[key],
(tf.__internal__.function.Function, save_impl.LayerCall))):
(tf.__internal__.function.Function,
tf.types.experimental.ConcreteFunction,
save_impl.LayerCall))):
raise ValueError(
'Function dictionary contained a non-function object: {} (for key'
' {})'.format(function_dict[key], key))
......
......@@ -14,9 +14,9 @@
# ==============================================================================
"""Utils related to keras model saving."""
# pylint: disable=g-bad-import-order, g-direct-tensorflow-import
import tensorflow.compat.v2 as tf
import collections
import copy
import os
from keras import backend as K
......@@ -28,6 +28,7 @@ from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
# pylint: enable=g-bad-import-order, g-direct-tensorflow-import
def extract_model_metrics(model):
......@@ -50,7 +51,7 @@ def extract_model_metrics(model):
return None
def model_input_signature(model, keep_original_batch_size=False):
def model_call_inputs(model, keep_original_batch_size=False):
"""Inspect model to get its input signature.
The model's input signature is a list with a single (possibly-nested) object.
......@@ -68,21 +69,15 @@ def model_input_signature(model, keep_original_batch_size=False):
`None`.
Returns:
A list containing either a single TensorSpec or an object with nested
TensorSpecs. This list does not contain the `training` argument.
A tuple containing `(args, kwargs)` TensorSpecs of the model call function
inputs.
`kwargs` does not contain the `training` argument.
"""
input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access
input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
if input_specs is None:
return None
return None, None
input_specs = _enforce_names_consistency(input_specs)
# Return a list with a single element as the model's input signature.
if isinstance(input_specs,
collections.abc.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list.
return input_specs
else:
return [input_specs]
return input_specs
def raise_model_input_error(model):
......@@ -111,22 +106,23 @@ def trace_model_call(model, input_signature=None):
if isinstance(model.call, tf.__internal__.function.Function):
input_signature = model.call.input_signature
if input_signature is None:
input_signature = model_input_signature(model)
if input_signature:
model_args = input_signature
model_kwargs = {}
else:
model_args, model_kwargs = model_call_inputs(model)
input_signature = model_args # store
if input_signature is None:
raise_model_input_error(model)
if model_args is None:
raise_model_input_error(model)
@tf.function(input_signature=input_signature)
def _wrapped_model(*args):
@tf.function
def _wrapped_model(*args, **kwargs):
"""A concrete tf.function that wraps the model's call function."""
# When given a single input, Keras models will call the model on the tensor
# rather than a list consisting of the single tensor.
inputs = args[0] if len(input_signature) == 1 else list(args)
kwargs['training'] = False
with base_layer_utils.call_context().enter(
model, inputs=inputs, build_graph=False, training=False, saving=True):
outputs = model(inputs, training=False)
model, inputs=None, build_graph=False, training=False, saving=True):
outputs = model(*args, **kwargs)
# Outputs always has to be a flat dict.
output_names = model.output_names # Functional Model.
......@@ -136,7 +132,7 @@ def trace_model_call(model, input_signature=None):
outputs = tf.nest.flatten(outputs)
return {name: output for name, output in zip(output_names, outputs)}
return _wrapped_model
return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
def model_metadata(model, include_optimizer=True, require_config=True):
......
......@@ -104,8 +104,10 @@ class TraceModelCallTest(keras_parameterized.TestCase):
model = testing_utils.get_multi_io_model(branch_a, branch_b)
input_a_np = np.random.random((10, input_dim)).astype(np.float32)
input_b_np = np.random.random((10, input_dim)).astype(np.float32)
input_a_ts = tf.constant(
np.random.random((10, input_dim)).astype(np.float32))
input_b_ts = tf.constant(
np.random.random((10, input_dim)).astype(np.float32))
if testing_utils.get_model_type() == 'subclass':
with self.assertRaisesRegex(ValueError, 'input shapes have not been set'):
......@@ -122,8 +124,15 @@ class TraceModelCallTest(keras_parameterized.TestCase):
epochs=2)
fn = saving_utils.trace_model_call(model)
signature_outputs = fn([input_a_np, input_b_np])
outputs = model([input_a_np, input_b_np])
# tf.function requires that the input structures match when calling a
# ConcreteFunction. For some reason V1 models defines the inputs as a list,
# while V2 models sets the inputs as a tuple.
if (not tf.executing_eagerly() and
testing_utils.get_model_type() != 'functional'):
signature_outputs = fn([input_a_ts, input_b_ts])
else:
signature_outputs = fn((input_a_ts, input_b_ts))
outputs = model([input_a_ts, input_b_ts])
if model.output_names:
expected_outputs = {
model.output_names[0]: outputs[0],
......@@ -140,7 +149,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
model_input = {'x': tf.constant([[1.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))
self.assertAllClose({'output_1': [[1.]]}, fn(model_input))
columns = [
tf.feature_column.numeric_column('x'),
......@@ -151,8 +160,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
'y': tf.constant([[2.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1., 2.]]},
fn({'x': [[1.]], 'y': [[2.]]}))
self.assertAllClose({'output_1': [[1., 2.]]}, fn(model_input))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_specify_input_signature(self):
......@@ -222,7 +230,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
train_step(x, y)
fn = saving_utils.trace_model_call(model)
self.assertEqual(fn.input_signature[0].shape.as_list(),
self.assertEqual(fn.structured_input_signature[0][0].shape.as_list(),
tf.TensorShape([None, 5]).as_list())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册