未验证 提交 2646d230 编写于 作者: G Goldie Gadde 提交者: GitHub

Merge pull request #32669 from tensorflow/ggadde-cp-19

[r2.0-CherryPick]:[tf.data] Avoid double conversion to a tensor during input normalizat…
...@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec ...@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
...@@ -83,24 +84,31 @@ def normalize_element(element): ...@@ -83,24 +84,31 @@ def normalize_element(element):
components = nest.flatten(element) components = nest.flatten(element)
normalized_components = [] normalized_components = []
with ops.name_scope("normalize_element"): with ops.name_scope("normalize_element"):
# Imported here to avoid circular dependency # Imported here to avoid circular dependency.
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
for i, t in enumerate(components): for i, t in enumerate(components):
spec = type_spec_from_value(t) try:
if isinstance(spec, sparse_tensor.SparseTensorSpec): spec = type_spec_from_value(t, use_fallback=False)
normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) except TypeError:
elif isinstance(spec, ragged_tensor.RaggedTensorSpec): # TypeError indicates it was not possible to compute a `TypeSpec` for
normalized_components.append( # the value. As a fallback try converting the value to a tensor.
ragged_tensor.convert_to_tensor_or_ragged_tensor(
t, name="component_%d" % i))
elif isinstance(
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
normalized_components.append(t)
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
normalized_components.append( normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i)) ops.convert_to_tensor(t, name="component_%d" % i))
else:
if isinstance(spec, sparse_tensor.SparseTensorSpec):
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
normalized_components.append(
ragged_tensor.convert_to_tensor_or_ragged_tensor(
t, name="component_%d" % i))
elif isinstance(
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
normalized_components.append(t)
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i))
return nest.pack_sequence_as(element, normalized_components) return nest.pack_sequence_as(element, normalized_components)
...@@ -392,11 +400,13 @@ def are_compatible(spec1, spec2): ...@@ -392,11 +400,13 @@ def are_compatible(spec1, spec2):
return True return True
def type_spec_from_value(element): def type_spec_from_value(element, use_fallback=True):
"""Creates a type specification for the given value. """Creates a type specification for the given value.
Args: Args:
element: The element to create the type specification for. element: The element to create the type specification for.
use_fallback: Whether to fall back to converting the element to a tensor
in order to compute its `TypeSpec`.
Returns: Returns:
A nested structure of `TypeSpec`s that represents the type specification A nested structure of `TypeSpec`s that represents the type specification
...@@ -432,14 +442,16 @@ def type_spec_from_value(element): ...@@ -432,14 +442,16 @@ def type_spec_from_value(element):
# `element` is not a namedtuple # `element` is not a namedtuple
return tuple([type_spec_from_value(v) for v in element]) return tuple([type_spec_from_value(v) for v in element])
# Fallback: try converting value to a tensor. if use_fallback:
try: # As a fallback try converting the element to a tensor.
tensor = ops.convert_to_tensor(element) try:
spec = type_spec_from_value(tensor) tensor = ops.convert_to_tensor(element)
if spec is not None: spec = type_spec_from_value(tensor)
return spec if spec is not None:
except (ValueError, TypeError): return spec
pass except (ValueError, TypeError) as e:
logging.vlog(
3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))
raise TypeError("Could not build a TypeSpec for %r with type %s" % raise TypeError("Could not build a TypeSpec for %r with type %s" %
(element, type(element).__name__)) (element, type(element).__name__))
...@@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow ...@@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
...@@ -483,8 +484,9 @@ def type_spec_from_value(value): ...@@ -483,8 +484,9 @@ def type_spec_from_value(value):
spec = _type_spec_from_value(tensor) spec = _type_spec_from_value(tensor)
if spec is not None: if spec is not None:
return spec return spec
except (ValueError, TypeError): except (ValueError, TypeError) as e:
pass logging.vlog(
3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))
raise TypeError("Could not build a TypeSpec for %r with type %s" % raise TypeError("Could not build a TypeSpec for %r with type %s" %
(value, type(value).__name__)) (value, type(value).__name__))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册