提交 1721fa8b 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Automated rollback of commit e42a66b2

PiperOrigin-RevId: 235411715
上级 1ad34b4f
...@@ -2366,8 +2366,6 @@ class StructuredFunctionWrapper(object): ...@@ -2366,8 +2366,6 @@ class StructuredFunctionWrapper(object):
else: else:
defun_kwargs.update({"func_name": func_name}) defun_kwargs.update({"func_name": func_name})
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes( @eager_function.defun_with_attributes(
input_signature=[ input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
...@@ -2375,7 +2373,6 @@ class StructuredFunctionWrapper(object): ...@@ -2375,7 +2373,6 @@ class StructuredFunctionWrapper(object):
self._input_structure._flat_shapes, self._input_structure._flat_shapes,
self._input_structure._flat_types) self._input_structure._flat_types)
], ],
autograph=False,
attributes=defun_kwargs) attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring def wrapper_fn(*args): # pylint: disable=missing-docstring
ret = _wrapper_helper(*args) ret = _wrapper_helper(*args)
......
...@@ -42,15 +42,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -42,15 +42,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
gen_dataset_ops.multi_device_iterator_to_string_handle( gen_dataset_ops.multi_device_iterator_to_string_handle(
multi_device_iterator_resource)) multi_device_iterator_resource))
# TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun()
@function.defun(autograph=False) # Pure graph code.
def _init_func(): def _init_func():
return multi_device_iterator_string_handle return multi_device_iterator_string_handle
init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun()
@function.defun(autograph=False) # Pure graph code.
def _remote_init_func(): def _remote_init_func():
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
...@@ -61,10 +59,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -61,10 +59,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access
self._init_captured_args = self._init_func.captured_inputs self._init_captured_args = self._init_func.captured_inputs
# TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _next_func(string_handle): def _next_func(string_handle):
# pylint: disable=protected-access # pylint: disable=protected-access
multi_device_iterator = ( multi_device_iterator = (
...@@ -81,11 +76,9 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -81,11 +76,9 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun_with_attributes( @function.defun_with_attributes(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)], input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
attributes={"experimental_ints_on_device": True}, attributes={"experimental_ints_on_device": True})
autograph=False) # Pure graph code.
def _remote_next_func(string_handle): def _remote_next_func(string_handle):
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
...@@ -101,19 +94,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -101,19 +94,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
if arg == incarnation_id: if arg == incarnation_id:
self._incarnation_id_index = i self._incarnation_id_index = i
# TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _finalize_func(unused_string_handle): def _finalize_func(unused_string_handle):
return array_ops.constant(0, dtypes.int64) return array_ops.constant(0, dtypes.int64)
finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough. @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _remote_finalize_func(string_handle): def _remote_finalize_func(string_handle):
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
......
...@@ -514,7 +514,7 @@ class Function(object): ...@@ -514,7 +514,7 @@ class Function(object):
"""Make and call a `ConcreteFunction` which initializes variables.""" """Make and call a `ConcreteFunction` which initializes variables."""
# Note: using defun here avoids an infinite recursion. # Note: using defun here avoids an infinite recursion.
@function_lib.defun(autograph=False) # Pure graph code. @function_lib.defun
def initialize_variables(): def initialize_variables():
for v, init in initializer_map.items(): for v, init in initializer_map.items():
with ops.init_scope(): with ops.init_scope():
......
...@@ -633,8 +633,7 @@ def func_graph_from_py_func(name, ...@@ -633,8 +633,7 @@ def func_graph_from_py_func(name,
# Wrapping around a decorator allows checks like tf_inspect.getargspec # Wrapping around a decorator allows checks like tf_inspect.getargspec
# to be accurate. # to be accurate.
converted_func = tf_decorator.make_decorator(original_func, wrapper) converted_func = tf_decorator.make_decorator(original_func, wrapper)
python_func = tf_decorator.rewrap(python_func, original_func, tf_decorator.rewrap(python_func, original_func, converted_func)
converted_func)
func_outputs = python_func(*func_args, **func_kwargs) func_outputs = python_func(*func_args, **func_kwargs)
......
...@@ -138,10 +138,6 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -138,10 +138,6 @@ def rewrap(decorator_func, previous_target, new_target):
decorator_func: Callable returned by `wrap`. decorator_func: Callable returned by `wrap`.
previous_target: Callable that needs to be replaced. previous_target: Callable that needs to be replaced.
new_target: Callable to replace previous_target with. new_target: Callable to replace previous_target with.
Returns:
The updated decorator. If decorator_func is not a tf_decorator, new_target
is returned.
""" """
# Because the process mutates the decorator, we only need to alter the # Because the process mutates the decorator, we only need to alter the
# innermost function that wraps previous_target. # innermost function that wraps previous_target.
...@@ -149,20 +145,14 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -149,20 +145,14 @@ def rewrap(decorator_func, previous_target, new_target):
innermost_decorator = None innermost_decorator = None
target = None target = None
while hasattr(cur, '_tf_decorator'): while hasattr(cur, '_tf_decorator'):
assert cur is not None
innermost_decorator = cur innermost_decorator = cur
target = getattr(cur, '_tf_decorator') target = getattr(cur, '_tf_decorator')
if target.decorated_target is previous_target: if target.decorated_target is previous_target:
break break
cur = target.decorated_target cur = target.decorated_target
# If decorator_func is not a decorator, new_target replaces it directly.
if innermost_decorator is None: if innermost_decorator is None:
# Consistency check. The caller should always pass the result of return
# tf_decorator.unwrap as previous_target. If decorator_func is not a
# decorator, that will have returned decorator_func itself.
assert decorator_func is previous_target
return new_target
target.decorated_target = new_target target.decorated_target = new_target
...@@ -178,8 +168,6 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -178,8 +168,6 @@ def rewrap(decorator_func, previous_target, new_target):
else: else:
innermost_decorator.__wrapped__ = new_target innermost_decorator.__wrapped__ = new_target
return decorator_func
def unwrap(maybe_tf_decorator): def unwrap(maybe_tf_decorator):
"""Unwraps an object into a list of TFDecorators and a final target. """Unwraps an object into a list of TFDecorators and a final target.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册