diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 645d95bdd68bc3d1aa5e944ea94d8ec01037684c..ce82428620060bc3ff8518c16ad90166d7f2b377 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -83,24 +84,31 @@ def normalize_element(element): components = nest.flatten(element) normalized_components = [] with ops.name_scope("normalize_element"): - # Imported here to avoid circular dependency + # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top for i, t in enumerate(components): - spec = type_spec_from_value(t) - if isinstance(spec, sparse_tensor.SparseTensorSpec): - normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) - elif isinstance(spec, ragged_tensor.RaggedTensorSpec): - normalized_components.append( - ragged_tensor.convert_to_tensor_or_ragged_tensor( - t, name="component_%d" % i)) - elif isinstance( - spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)): - normalized_components.append(t) - elif isinstance(t, composite_tensor.CompositeTensor): - normalized_components.append(t) - else: + try: + spec = type_spec_from_value(t, use_fallback=False) + except TypeError: + # TypeError indicates it was not possible to compute a `TypeSpec` for + # the value. As a fallback try converting the value to a tensor. normalized_components.append( ops.convert_to_tensor(t, name="component_%d" % i)) + else: + if isinstance(spec, sparse_tensor.SparseTensorSpec): + normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) + elif isinstance(spec, ragged_tensor.RaggedTensorSpec): + normalized_components.append( + ragged_tensor.convert_to_tensor_or_ragged_tensor( + t, name="component_%d" % i)) + elif isinstance( + spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)): + normalized_components.append(t) + elif isinstance(t, composite_tensor.CompositeTensor): + normalized_components.append(t) + else: + normalized_components.append( + ops.convert_to_tensor(t, name="component_%d" % i)) return nest.pack_sequence_as(element, normalized_components) @@ -392,11 +400,13 @@ def are_compatible(spec1, spec2): return True -def type_spec_from_value(element): +def type_spec_from_value(element, use_fallback=True): """Creates a type specification for the given value. Args: element: The element to create the type specification for. + use_fallback: Whether to fall back to converting the element to a tensor + in order to compute its `TypeSpec`. Returns: A nested structure of `TypeSpec`s that represents the type specification @@ -432,14 +442,16 @@ def type_spec_from_value(element): # `element` is not a namedtuple return tuple([type_spec_from_value(v) for v in element]) - # Fallback: try converting value to a tensor. - try: - tensor = ops.convert_to_tensor(element) - spec = type_spec_from_value(tensor) - if spec is not None: - return spec - except (ValueError, TypeError): - pass + if use_fallback: + # As a fallback try converting the element to a tensor. + try: + tensor = ops.convert_to_tensor(element) + spec = type_spec_from_value(tensor) + if spec is not None: + return spec + except (ValueError, TypeError) as e: + logging.vlog( + 3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e)) raise TypeError("Could not build a TypeSpec for %r with type %s" % (element, type(element).__name__)) diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index ffc93b06c671b081994a1b375a43c19d5596d88f..e462fd4ca47fd4b8c7847c86d3ee786b1d6f0df6 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -483,8 +484,9 @@ def type_spec_from_value(value): spec = _type_spec_from_value(tensor) if spec is not None: return spec - except (ValueError, TypeError): - pass + except (ValueError, TypeError) as e: + logging.vlog( + 3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e)) raise TypeError("Could not build a TypeSpec for %r with type %s" % (value, type(value).__name__))