From e42a66b26372845ea41aea6c12bc3e4f09f1efc5 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Fri, 22 Feb 2019 14:05:33 -0800 Subject: [PATCH] Fix rewrap to always return the updated entity. Use it as such at the autograph integration point. In turn, temporarily disable autograph for dataset functions until we reduce the conversion overhead. Fixes #25159. With thanks to @ktaebum for identifying this fix. PiperOrigin-RevId: 235255035 --- tensorflow/python/data/ops/dataset_ops.py | 3 +++ .../data/ops/multi_device_iterator_ops.py | 25 ++++++++++++++----- tensorflow/python/eager/def_function.py | 2 +- tensorflow/python/framework/func_graph.py | 3 ++- tensorflow/python/util/tf_decorator.py | 14 ++++++++++- 5 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 91a07a5534d..3f50738f17f 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2366,6 +2366,8 @@ class StructuredFunctionWrapper(object): else: defun_kwargs.update({"func_name": func_name}) + # TODO(b/124254153): Enable autograph once the overhead is low enough. + # TODO(mdan): Make sure autograph recurses into _wrapper_helper when on. @eager_function.defun_with_attributes( input_signature=[ tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension @@ -2373,6 +2375,7 @@ class StructuredFunctionWrapper(object): self._input_structure._flat_shapes, self._input_structure._flat_types) ], + autograph=False, attributes=defun_kwargs) def wrapper_fn(*args): # pylint: disable=missing-docstring ret = _wrapper_helper(*args) diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index 36e1e311d39..f253f8bf42a 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -42,13 +42,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): gen_dataset_ops.multi_device_iterator_to_string_handle( multi_device_iterator_resource)) - @function.defun() + # TODO(b/124254153): Enable autograph once the overhead is low enough. + @function.defun(autograph=False) # Pure graph code. def _init_func(): return multi_device_iterator_string_handle init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access - @function.defun() + # TODO(b/124254153): Enable autograph once the overhead is low enough. + @function.defun(autograph=False) # Pure graph code. def _remote_init_func(): return functional_ops.remote_call( target=source_device, @@ -59,7 +61,10 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_captured_args = self._init_func.captured_inputs - @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) + # TODO(b/124254153): Enable autograph once the overhead is low enough. + @function.defun( + input_signature=[tensor_spec.TensorSpec([], dtypes.string)], + autograph=False) # Pure graph code. def _next_func(string_handle): # pylint: disable=protected-access multi_device_iterator = ( @@ -76,9 +81,11 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access + # TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun_with_attributes( input_signature=[tensor_spec.TensorSpec([], dtypes.string)], - attributes={"experimental_ints_on_device": True}) + attributes={"experimental_ints_on_device": True}, + autograph=False) # Pure graph code. def _remote_next_func(string_handle): return functional_ops.remote_call( target=source_device, @@ -94,13 +101,19 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): if arg == incarnation_id: self._incarnation_id_index = i - @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) + # TODO(b/124254153): Enable autograph once the overhead is low enough. + @function.defun( + input_signature=[tensor_spec.TensorSpec([], dtypes.string)], + autograph=False) # Pure graph code. def _finalize_func(unused_string_handle): return array_ops.constant(0, dtypes.int64) finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access - @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) + # TODO(b/124254153): Enable autograph once the overhead is low enough. + @function.defun( + input_signature=[tensor_spec.TensorSpec([], dtypes.string)], + autograph=False) # Pure graph code. def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=source_device, diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 1d54973487c..330b2e0a763 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -514,7 +514,7 @@ class Function(object): """Make and call a `ConcreteFunction` which initializes variables.""" # Note: using defun here avoids an infinite recursion. - @function_lib.defun + @function_lib.defun(autograph=False) # Pure graph code. def initialize_variables(): for v, init in initializer_map.items(): with ops.init_scope(): diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 9097a8dd1f0..aee8d4c6584 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -633,7 +633,8 @@ def func_graph_from_py_func(name, # Wrapping around a decorator allows checks like tf_inspect.getargspec # to be accurate. converted_func = tf_decorator.make_decorator(original_func, wrapper) - tf_decorator.rewrap(python_func, original_func, converted_func) + python_func = tf_decorator.rewrap(python_func, original_func, + converted_func) func_outputs = python_func(*func_args, **func_kwargs) diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py index 6c71f70c4e7..033d35b2fe7 100644 --- a/tensorflow/python/util/tf_decorator.py +++ b/tensorflow/python/util/tf_decorator.py @@ -138,6 +138,10 @@ def rewrap(decorator_func, previous_target, new_target): decorator_func: Callable returned by `wrap`. previous_target: Callable that needs to be replaced. new_target: Callable to replace previous_target with. + + Returns: + The updated decorator. If decorator_func is not a tf_decorator, new_target + is returned. """ # Because the process mutates the decorator, we only need to alter the # innermost function that wraps previous_target. @@ -145,14 +149,20 @@ def rewrap(decorator_func, previous_target, new_target): innermost_decorator = None target = None while hasattr(cur, '_tf_decorator'): + assert cur is not None innermost_decorator = cur target = getattr(cur, '_tf_decorator') if target.decorated_target is previous_target: break cur = target.decorated_target + # If decorator_func is not a decorator, new_target replaces it directly. if innermost_decorator is None: - return + # Consistency check. The caller should always pass the result of + # tf_decorator.unwrap as previous_target. If decorator_func is not a + # decorator, that will have returned decorator_func itself. + assert decorator_func is previous_target + return new_target target.decorated_target = new_target @@ -168,6 +178,8 @@ def rewrap(decorator_func, previous_target, new_target): else: innermost_decorator.__wrapped__ = new_target + return decorator_func + def unwrap(maybe_tf_decorator): """Unwraps an object into a list of TFDecorators and a final target. -- GitLab