提交 e42a66b2 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Fix rewrap to always return the updated entity. Use it as such at the autograph integration point.

Fix rewrap to always return the updated entity. Use it as such at the autograph integration point. In turn, temporarily disable autograph for dataset functions until we reduce the conversion overhead.
Fixes #25159.
With thanks to @ktaebum for identifying this fix.

PiperOrigin-RevId: 235255035
上级 d56074d1
......@@ -2366,6 +2366,8 @@ class StructuredFunctionWrapper(object):
else:
defun_kwargs.update({"func_name": func_name})
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes(
input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
......@@ -2373,6 +2375,7 @@ class StructuredFunctionWrapper(object):
self._input_structure._flat_shapes,
self._input_structure._flat_types)
],
autograph=False,
attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring
ret = _wrapper_helper(*args)
......
......@@ -42,13 +42,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
gen_dataset_ops.multi_device_iterator_to_string_handle(
multi_device_iterator_resource))
@function.defun()
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _init_func():
return multi_device_iterator_string_handle
init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access
@function.defun()
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _remote_init_func():
return functional_ops.remote_call(
target=source_device,
......@@ -59,7 +61,10 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access
self._init_captured_args = self._init_func.captured_inputs
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _next_func(string_handle):
# pylint: disable=protected-access
multi_device_iterator = (
......@@ -76,9 +81,11 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun_with_attributes(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
attributes={"experimental_ints_on_device": True})
attributes={"experimental_ints_on_device": True},
autograph=False) # Pure graph code.
def _remote_next_func(string_handle):
return functional_ops.remote_call(
target=source_device,
......@@ -94,13 +101,19 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
if arg == incarnation_id:
self._incarnation_id_index = i
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _finalize_func(unused_string_handle):
return array_ops.constant(0, dtypes.int64)
finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _remote_finalize_func(string_handle):
return functional_ops.remote_call(
target=source_device,
......
......@@ -514,7 +514,7 @@ class Function(object):
"""Make and call a `ConcreteFunction` which initializes variables."""
# Note: using defun here avoids an infinite recursion.
@function_lib.defun
@function_lib.defun(autograph=False) # Pure graph code.
def initialize_variables():
for v, init in initializer_map.items():
with ops.init_scope():
......
......@@ -633,7 +633,8 @@ def func_graph_from_py_func(name,
# Wrapping around a decorator allows checks like tf_inspect.getargspec
# to be accurate.
converted_func = tf_decorator.make_decorator(original_func, wrapper)
tf_decorator.rewrap(python_func, original_func, converted_func)
python_func = tf_decorator.rewrap(python_func, original_func,
converted_func)
func_outputs = python_func(*func_args, **func_kwargs)
......
......@@ -138,6 +138,10 @@ def rewrap(decorator_func, previous_target, new_target):
decorator_func: Callable returned by `wrap`.
previous_target: Callable that needs to be replaced.
new_target: Callable to replace previous_target with.
Returns:
The updated decorator. If decorator_func is not a tf_decorator, new_target
is returned.
"""
# Because the process mutates the decorator, we only need to alter the
# innermost function that wraps previous_target.
......@@ -145,14 +149,20 @@ def rewrap(decorator_func, previous_target, new_target):
innermost_decorator = None
target = None
while hasattr(cur, '_tf_decorator'):
assert cur is not None
innermost_decorator = cur
target = getattr(cur, '_tf_decorator')
if target.decorated_target is previous_target:
break
cur = target.decorated_target
# If decorator_func is not a decorator, new_target replaces it directly.
if innermost_decorator is None:
return
# Consistency check. The caller should always pass the result of
# tf_decorator.unwrap as previous_target. If decorator_func is not a
# decorator, that will have returned decorator_func itself.
assert decorator_func is previous_target
return new_target
target.decorated_target = new_target
......@@ -168,6 +178,8 @@ def rewrap(decorator_func, previous_target, new_target):
else:
innermost_decorator.__wrapped__ = new_target
return decorator_func
def unwrap(maybe_tf_decorator):
"""Unwraps an object into a list of TFDecorators and a final target.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册