提交 265de523 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Fix input mapping issue when model is constructed/tested with dict input tensor.

The mapping of the dict input tensors was not correct since it was still using the tensor name, rather than the key of the tensor when build the model. This cause the issue down the stream when the inputs are provided with unknown keys.

We had some backup logic, which will probably do correct things, eg just flat the dict to keep the original order, which was correct most of the case, but not very reliable. In this change, we make the behavior change:

1. When model is build with dict input tensors, the key of the tensor, instead of the name, will be used to map the tensor with input data.
2. Unknown keys in the input data will result into a warning. We didn't throw error since user might do it intentionally, eg using part of the model to test with full input data.

PiperOrigin-RevId: 317776370
Change-Id: I91983443f2b770cb0b45ddb7726f52708cb91d61
上级 e60cf089
......@@ -22,6 +22,7 @@ from __future__ import print_function
import collections
import copy
import itertools
import warnings
from six.moves import zip # pylint: disable=redefined-builtin
......@@ -131,10 +132,10 @@ class Functional(training_lib.Model):
# Models constructed with a single Tensor or list of Tensors can
# be called with a dict, where the keys of the dict are the names
# of the `Input` objects. Extra keys are ignored.
# of the `Input` objects. Extra keys are ignored with warning.
self._enable_dict_to_input_mapping = (
not nest.is_sequence(self._nested_inputs) or
(isinstance(self._nested_inputs, (list, tuple)) and
(isinstance(self._nested_inputs, (list, tuple, dict)) and
not any(nest.is_sequence(t) for t in self._nested_inputs)))
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
......@@ -524,10 +525,27 @@ class Functional(training_lib.Model):
ref_inputs = self._nested_inputs
if not nest.is_sequence(ref_inputs):
ref_inputs = [self._nested_inputs]
if isinstance(ref_inputs, dict):
# In the case that the graph is constructed with dict input tensors,
# We will use the original dict key to map with the keys in the input
# data. Note that the model.inputs is using nest.flatten to process the
# input tensors, which means the dict input tensors are ordered by their
# keys.
ref_input_names = sorted(ref_inputs.keys())
else:
ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]
# Raise an warning if there are more input data comparing to input tensor
if len(tensors) > len(ref_input_names):
warnings.warn(
'Input dict contained keys {} which did not match any model input. '
'They will be ignored by the model.'.format(
[n for n in tensors.keys() if n not in ref_input_names])
)
try:
# Flatten in the order `Input`s were passed during Model construction.
return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
return [tensors[n] for n in ref_input_names]
except KeyError:
# TODO(b/151582614)
return nest.flatten(tensors)
......
......@@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
......@@ -43,6 +46,7 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.training.tracking.util import Checkpoint
......@@ -1565,6 +1569,48 @@ class DefaultShapeInferenceBehaviorTest(keras_parameterized.TestCase):
self.assertEqual(config['layers'][2]['inbound_nodes'],
[[['in1', 0, 0, {}], ['in2', 0, 0, {}]]])
@combinations.generate(combinations.combine(mode=['eager']))
def test_dict_inputs_tensors(self):
# Note that this test is running with v2 eager only, since the v1
# will behave differently wrt to dict input for training.
inputs = {
'sentence2': input_layer_lib.Input(
shape=(), name='a', dtype=dtypes.string),
'sentence1': input_layer_lib.Input(
shape=(), name='b', dtype=dtypes.string),
}
strlen = layers.Lambda(string_ops.string_length_v2)
diff = layers.Subtract()(
[strlen(inputs['sentence1']), strlen(inputs['sentence2'])])
diff = math_ops.cast(diff, dtypes.float32)
model = training_lib.Model(inputs, diff)
extra_keys = {
'sentence1': constant_op.constant(['brown fox', 'lazy dog']),
'sentence2': constant_op.constant(['owl', 'cheeky cat']),
'label': constant_op.constant([0, 1]),
}
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
model(extra_keys)
self.assertIn('ignored by the model', str(w[-1].message))
model.compile('sgd', 'mse')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
model.fit(extra_keys, y=constant_op.constant([0, 1]), steps_per_epoch=1)
self.assertIn('ignored by the model', str(w[-1].message))
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
model.evaluate(extra_keys, constant_op.constant([0, 1]))
self.assertIn('ignored by the model', str(w[-1].message))
# Make sure the model inputs are sorted with the dict keys.
self.assertEqual(model.inputs[0]._keras_history.layer.name, 'b')
self.assertEqual(model.inputs[1]._keras_history.layer.name, 'a')
class GraphUtilsTest(test.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册