diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 3f50738f17f2499ebefa9480f63053ea7010d120..91a07a5534d03dfee519a1da1cf70767e21fda20 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2366,8 +2366,6 @@ class StructuredFunctionWrapper(object): else: defun_kwargs.update({"func_name": func_name}) - # TODO(b/124254153): Enable autograph once the overhead is low enough. - # TODO(mdan): Make sure autograph recurses into _wrapper_helper when on. @eager_function.defun_with_attributes( input_signature=[ tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension @@ -2375,7 +2373,6 @@ class StructuredFunctionWrapper(object): self._input_structure._flat_shapes, self._input_structure._flat_types) ], - autograph=False, attributes=defun_kwargs) def wrapper_fn(*args): # pylint: disable=missing-docstring ret = _wrapper_helper(*args) diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index f253f8bf42a225d79db1acb5c0b03f48b40a9703..36e1e311d3913b1be7b2967e224865938c40de6e 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -42,15 +42,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): gen_dataset_ops.multi_device_iterator_to_string_handle( multi_device_iterator_resource)) - # TODO(b/124254153): Enable autograph once the overhead is low enough. - @function.defun(autograph=False) # Pure graph code. + @function.defun() def _init_func(): return multi_device_iterator_string_handle init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access - # TODO(b/124254153): Enable autograph once the overhead is low enough. - @function.defun(autograph=False) # Pure graph code. + @function.defun() def _remote_init_func(): return functional_ops.remote_call( target=source_device, @@ -61,10 +59,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_captured_args = self._init_func.captured_inputs - # TODO(b/124254153): Enable autograph once the overhead is low enough. - @function.defun( - input_signature=[tensor_spec.TensorSpec([], dtypes.string)], - autograph=False) # Pure graph code. + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _next_func(string_handle): # pylint: disable=protected-access multi_device_iterator = ( @@ -81,11 +76,9 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access - # TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun_with_attributes( input_signature=[tensor_spec.TensorSpec([], dtypes.string)], - attributes={"experimental_ints_on_device": True}, - autograph=False) # Pure graph code. + attributes={"experimental_ints_on_device": True}) def _remote_next_func(string_handle): return functional_ops.remote_call( target=source_device, @@ -101,19 +94,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): if arg == incarnation_id: self._incarnation_id_index = i - # TODO(b/124254153): Enable autograph once the overhead is low enough. - @function.defun( - input_signature=[tensor_spec.TensorSpec([], dtypes.string)], - autograph=False) # Pure graph code. + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _finalize_func(unused_string_handle): return array_ops.constant(0, dtypes.int64) finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access - # TODO(b/124254153): Enable autograph once the overhead is low enough. - @function.defun( - input_signature=[tensor_spec.TensorSpec([], dtypes.string)], - autograph=False) # Pure graph code. + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=source_device, diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 330b2e0a76391ab1e09b2b26aa69dece3b86f793..1d54973487ca4c6a0221e376954824d4eba2aacd 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -514,7 +514,7 @@ class Function(object): """Make and call a `ConcreteFunction` which initializes variables.""" # Note: using defun here avoids an infinite recursion. - @function_lib.defun(autograph=False) # Pure graph code. + @function_lib.defun def initialize_variables(): for v, init in initializer_map.items(): with ops.init_scope(): diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index aee8d4c65843c89c8ee07d4734269db07d46fb39..9097a8dd1f05de0dc271df18721557f979c44c29 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -633,8 +633,7 @@ def func_graph_from_py_func(name, # Wrapping around a decorator allows checks like tf_inspect.getargspec # to be accurate. converted_func = tf_decorator.make_decorator(original_func, wrapper) - python_func = tf_decorator.rewrap(python_func, original_func, - converted_func) + tf_decorator.rewrap(python_func, original_func, converted_func) func_outputs = python_func(*func_args, **func_kwargs) diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py index 033d35b2fe7bdf666520ac012e0042bdf576b103..6c71f70c4e75fb907536e2b296e9e8d5487fb88c 100644 --- a/tensorflow/python/util/tf_decorator.py +++ b/tensorflow/python/util/tf_decorator.py @@ -138,10 +138,6 @@ def rewrap(decorator_func, previous_target, new_target): decorator_func: Callable returned by `wrap`. previous_target: Callable that needs to be replaced. new_target: Callable to replace previous_target with. - - Returns: - The updated decorator. If decorator_func is not a tf_decorator, new_target - is returned. """ # Because the process mutates the decorator, we only need to alter the # innermost function that wraps previous_target. @@ -149,20 +145,14 @@ def rewrap(decorator_func, previous_target, new_target): innermost_decorator = None target = None while hasattr(cur, '_tf_decorator'): - assert cur is not None innermost_decorator = cur target = getattr(cur, '_tf_decorator') if target.decorated_target is previous_target: break cur = target.decorated_target - # If decorator_func is not a decorator, new_target replaces it directly. if innermost_decorator is None: - # Consistency check. The caller should always pass the result of - # tf_decorator.unwrap as previous_target. If decorator_func is not a - # decorator, that will have returned decorator_func itself. - assert decorator_func is previous_target - return new_target + return target.decorated_target = new_target @@ -178,8 +168,6 @@ def rewrap(decorator_func, previous_target, new_target): else: innermost_decorator.__wrapped__ = new_target - return decorator_func - def unwrap(maybe_tf_decorator): """Unwraps an object into a list of TFDecorators and a final target.