提交 9010c9e2 编写于 作者: E Edward Loper 提交者: TensorFlower Gardener

When converting ExtensionType inputs to match the expected dtype, only use...

When converting ExtensionType inputs to match the expected dtype, only use tf.cast if the value doesn't already have the expected dtype.  (Not all ExtensionTypes add dispatch handlers for the tf.cast method, so we should avoid calling it unless it's necessary.)

PiperOrigin-RevId: 395780095
上级 f4f65ef0
......@@ -643,9 +643,12 @@ class Functional(training_lib.Model):
tensor = tf.cast(tensor, dtype=ref_input.dtype)
elif tf_utils.is_extension_type(tensor):
# Dtype casting (If the extension type has a non-variant dtype and
# supports being cast)
# supports being cast). Only cast if necessary (since some extension
# types may not implement tf.cast).
tensor_dtype = getattr(tensor, 'dtype', None)
ref_input_dtype = getattr(ref_input, 'dtype', None)
if ref_input_dtype is not None and ref_input_dtype != tf.variant:
if (ref_input_dtype is not None and tensor_dtype is not None and
tensor_dtype != ref_input_dtype and ref_input_dtype != tf.variant):
tensor = tf.cast(tensor, dtype=ref_input_dtype)
return tensor
......
......@@ -14,11 +14,8 @@
#,============================================================================
"""Tests for layer graphs construction & handling."""
import tensorflow.compat.v2 as tf
import warnings
import numpy as np
from keras import backend
from keras import combinations
from keras import initializers
......@@ -34,7 +31,14 @@ from keras.engine import sequential
from keras.engine import training as training_lib
from keras.utils import layer_utils
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import extension_type
from tensorflow.python.training.tracking.util import Checkpoint
# pylint: enable=g-direct-tensorflow-import
class NetworkConstructionTest(keras_parameterized.TestCase):
......@@ -1168,6 +1172,36 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_dont_cast_composite_unless_necessary(self):
if not tf.executing_eagerly():
return # Creating Keras inputs from a type_spec only supported in eager.
# TODO(edloper): Change this to tf.experimental.ExtensionTyep once
# it's been released.
class MyType(extension_type.ExtensionType):
# TODO(edloper) Remove _shape and _dtype once Keras has been switched
# to use .shape and .dtype instead.
value: tf.Tensor
_shape = property(lambda self: self.value.shape)
shape = property(lambda self: self.value.shape)
_dtype = property(lambda self: self.value.dtype)
dtype = property(lambda self: self.value.dtype)
class Spec:
_shape = property(lambda self: self.value.shape)
shape = property(lambda self: self.value.shape)
_dtype = property(lambda self: self.value.dtype)
dtype = property(lambda self: self.value.dtype)
my_spec = MyType.Spec(tf.TensorSpec([5], tf.float32))
input1 = input_layer_lib.Input(type_spec=my_spec)
model = training_lib.Model([input1], input1)
model.compile(run_eagerly=testing_utils.should_run_eagerly())
model(MyType([1., 2., 3., 4., 5.])) # Does not require cast.
with self.assertRaises((ValueError, TypeError)):
model(MyType([1, 2, 3, 4, 5]))
@combinations.generate(combinations.keras_mode_combinations())
def test_composite_call_kwarg_derived_from_keras_layer(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册