Commit 1721fa8b authored by: Dan Moldovan    Committed by: TensorFlower Gardener

Automated rollback of commit e42a66b2

PiperOrigin-RevId: 235411715
Parent 1ad34b4f
......
@@ -2366,8 +2366,6 @@ class StructuredFunctionWrapper(object):
else:
defun_kwargs.update({"func_name": func_name})
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes(
input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
......
@@ -2375,7 +2373,6 @@ class StructuredFunctionWrapper(object):
self._input_structure._flat_shapes,
self._input_structure._flat_types)
],
autograph=False,
attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring
ret = _wrapper_helper(*args)
......
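The hunk headers above indicate that the explicit `autograph=False,` argument (and its TODO) is dropped from the `eager_function.defun_with_attributes` decorator, falling back to the decorator's default. As a hedged aside (not part of this diff), here is a minimal sketch of the same switch using the public `tf.function` API, the public counterpart of the internal `defun` decorators touched here:

```python
import tensorflow as tf

# Sketch only: illustrates the autograph flag toggled in the hunks above.
# With autograph=False the Python body is traced as plain graph code; with
# the default (True), Python control flow such as the `if` below is
# converted to graph ops by AutoGraph.
@tf.function(autograph=False)
def plain_graph(x):
  return x * 2  # no Python control flow, so AutoGraph is not needed


@tf.function  # autograph enabled by default
def converted(x):
  if x > 0:  # data-dependent `if`; needs AutoGraph (or tf.cond) in a graph
    return x
  return -x


print(plain_graph(tf.constant(3)))  # tf.Tensor(6, ...)
print(converted(tf.constant(-3)))   # tf.Tensor(3, ...)
```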
......
@@ -42,15 +42,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
gen_dataset_ops.multi_device_iterator_to_string_handle(
multi_device_iterator_resource))
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
@function.defun()
def _init_func():
return multi_device_iterator_string_handle
init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
@function.defun()
def _remote_init_func():
return functional_ops.remote_call(
target=source_device,
......
@@ -61,10 +59,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access
self._init_captured_args = self._init_func.captured_inputs
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
def _next_func(string_handle):
# pylint: disable=protected-access
multi_device_iterator = (
......
@@ -81,11 +76,9 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun_with_attributes(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
attributes={"experimental_ints_on_device": True},
autograph=False) # Pure graph code.
attributes={"experimental_ints_on_device": True})
def _remote_next_func(string_handle):
return functional_ops.remote_call(
target=source_device,
......
@@ -101,19 +94,13 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
if arg == incarnation_id:
self._incarnation_id_index = i
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
def _finalize_func(unused_string_handle):
return array_ops.constant(0, dtypes.int64)
finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
def _remote_finalize_func(string_handle):
return functional_ops.remote_call(
target=source_device,
......
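Besides the `autograph` flag, the `_next_func`/`_finalize_func` decorators above also pin an `input_signature`. A hedged sketch of that idea with the public API (not part of this diff; `tf.TensorSpec` stands in for the internal `tensor_spec.TensorSpec`):

```python
import tensorflow as tf

# Sketch only: an input_signature of a single scalar string, mirroring
# tensor_spec.TensorSpec([], dtypes.string) in the hunks above. The
# signature fixes the accepted shape/dtype, so the function is traced once
# and rejects mismatched arguments instead of retracing.
@tf.function(input_signature=[tf.TensorSpec([], tf.string)])
def handle_fn(string_handle):
  return tf.strings.length(string_handle)


print(handle_fn(tf.constant("abc")))  # tf.Tensor(3, ...)
# handle_fn(tf.constant([1, 2]))      # would raise: dtype/shape mismatch
```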
......
@@ -514,7 +514,7 @@ class Function(object):
"""Make and call a `ConcreteFunction` which initializes variables."""
# Note: using defun here avoids an infinite recursion.
@function_lib.defun(autograph=False) # Pure graph code.
@function_lib.defun
def initialize_variables():
for v, init in initializer_map.items():
with ops.init_scope():
......
......
@@ -633,8 +633,7 @@ def func_graph_from_py_func(name,
# Wrapping around a decorator allows checks like tf_inspect.getargspec
# to be accurate.
converted_func = tf_decorator.make_decorator(original_func, wrapper)
python_func = tf_decorator.rewrap(python_func, original_func,
converted_func)
tf_decorator.rewrap(python_func, original_func, converted_func)
func_outputs = python_func(*func_args, **func_kwargs)
......
......
@@ -138,10 +138,6 @@ def rewrap(decorator_func, previous_target, new_target):
decorator_func: Callable returned by `wrap`.
previous_target: Callable that needs to be replaced.
new_target: Callable to replace previous_target with.
Returns:
The updated decorator. If decorator_func is not a tf_decorator, new_target
is returned.
"""
# Because the process mutates the decorator, we only need to alter the
# innermost function that wraps previous_target.
......
@@ -149,20 +145,14 @@ def rewrap(decorator_func, previous_target, new_target):
innermost_decorator = None
target = None
while hasattr(cur, '_tf_decorator'):
assert cur is not None
innermost_decorator = cur
target = getattr(cur, '_tf_decorator')
if target.decorated_target is previous_target:
break
cur = target.decorated_target
# If decorator_func is not a decorator, new_target replaces it directly.
if innermost_decorator is None:
# Consistency check. The caller should always pass the result of
# tf_decorator.unwrap as previous_target. If decorator_func is not a
# decorator, that will have returned decorator_func itself.
assert decorator_func is previous_target
return new_target
return
target.decorated_target = new_target
......
@@ -178,8 +168,6 @@ def rewrap(decorator_func, previous_target, new_target):
else:
innermost_decorator.__wrapped__ = new_target
return decorator_func
def unwrap(maybe_tf_decorator):
"""Unwraps an object into a list of TFDecorators and a final target.
......
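The hunk counts above suggest that after this rollback `tf_decorator.rewrap` mutates the decorator chain in place and returns nothing (the `Returns:` docstring section and `return decorator_func` are removed, and the call site in `func_graph_from_py_func` drops the assignment). A hedged sketch of that in-place contract, using hypothetical names and assuming the post-rollback behavior shown above:

```python
from tensorflow.python.util import tf_decorator


def target(x):       # the callable originally wrapped
  return x


def new_target(x):   # the callable to splice in
  return x * 2


def outer(*args, **kwargs):
  # Stands in for an existing decorator around `target`; make_decorator
  # sets outer.__wrapped__ for us.
  return outer.__wrapped__(*args, **kwargs)


decorated = tf_decorator.make_decorator(target, outer)
print(decorated(3))  # 3 -- still calls the original target

# Post-rollback, rewrap returns None and updates `decorated` in place, so
# callers (like func_graph_from_py_func above) no longer reassign its result.
tf_decorator.rewrap(decorated, target, new_target)
print(decorated(3))  # 6 -- the innermost target is now new_target
```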