提交 e42a66b2 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Fix rewrap to always return the updated entity. Use it as such at the...

Fix rewrap to always return the updated entity. Use it as such at the autograph integration point. In turn, temporarily disable autograph for dataset functions until we reduce the conversion overhead.
Fixes #25159.
With thanks to @ktaebum for identifying this fix.

PiperOrigin-RevId: 235255035
上级 d56074d1
...@@ -2366,6 +2366,8 @@ class StructuredFunctionWrapper(object): ...@@ -2366,6 +2366,8 @@ class StructuredFunctionWrapper(object):
else: else:
defun_kwargs.update({"func_name": func_name}) defun_kwargs.update({"func_name": func_name})
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes( @eager_function.defun_with_attributes(
input_signature=[ input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
...@@ -2373,6 +2375,7 @@ class StructuredFunctionWrapper(object): ...@@ -2373,6 +2375,7 @@ class StructuredFunctionWrapper(object):
self._input_structure._flat_shapes, self._input_structure._flat_shapes,
self._input_structure._flat_types) self._input_structure._flat_types)
], ],
autograph=False,
attributes=defun_kwargs) attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring def wrapper_fn(*args): # pylint: disable=missing-docstring
ret = _wrapper_helper(*args) ret = _wrapper_helper(*args)
......
...@@ -42,13 +42,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -42,13 +42,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
gen_dataset_ops.multi_device_iterator_to_string_handle( gen_dataset_ops.multi_device_iterator_to_string_handle(
multi_device_iterator_resource)) multi_device_iterator_resource))
@function.defun() # TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _init_func(): def _init_func():
return multi_device_iterator_string_handle return multi_device_iterator_string_handle
init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access
@function.defun() # TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _remote_init_func(): def _remote_init_func():
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
...@@ -59,7 +61,10 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -59,7 +61,10 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access
self._init_captured_args = self._init_func.captured_inputs self._init_captured_args = self._init_func.captured_inputs
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) # TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _next_func(string_handle): def _next_func(string_handle):
# pylint: disable=protected-access # pylint: disable=protected-access
multi_device_iterator = ( multi_device_iterator = (
...@@ -76,9 +81,11 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -76,9 +81,11 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun_with_attributes( @function.defun_with_attributes(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)], input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
attributes={"experimental_ints_on_device": True}) attributes={"experimental_ints_on_device": True},
autograph=False) # Pure graph code.
def _remote_next_func(string_handle): def _remote_next_func(string_handle):
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
...@@ -94,13 +101,19 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2): ...@@ -94,13 +101,19 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
if arg == incarnation_id: if arg == incarnation_id:
self._incarnation_id_index = i self._incarnation_id_index = i
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) # TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _finalize_func(unused_string_handle): def _finalize_func(unused_string_handle):
return array_ops.constant(0, dtypes.int64) return array_ops.constant(0, dtypes.int64)
finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) # TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _remote_finalize_func(string_handle): def _remote_finalize_func(string_handle):
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
......
...@@ -514,7 +514,7 @@ class Function(object): ...@@ -514,7 +514,7 @@ class Function(object):
"""Make and call a `ConcreteFunction` which initializes variables.""" """Make and call a `ConcreteFunction` which initializes variables."""
# Note: using defun here avoids an infinite recursion. # Note: using defun here avoids an infinite recursion.
@function_lib.defun @function_lib.defun(autograph=False) # Pure graph code.
def initialize_variables(): def initialize_variables():
for v, init in initializer_map.items(): for v, init in initializer_map.items():
with ops.init_scope(): with ops.init_scope():
......
...@@ -633,7 +633,8 @@ def func_graph_from_py_func(name, ...@@ -633,7 +633,8 @@ def func_graph_from_py_func(name,
# Wrapping around a decorator allows checks like tf_inspect.getargspec # Wrapping around a decorator allows checks like tf_inspect.getargspec
# to be accurate. # to be accurate.
converted_func = tf_decorator.make_decorator(original_func, wrapper) converted_func = tf_decorator.make_decorator(original_func, wrapper)
tf_decorator.rewrap(python_func, original_func, converted_func) python_func = tf_decorator.rewrap(python_func, original_func,
converted_func)
func_outputs = python_func(*func_args, **func_kwargs) func_outputs = python_func(*func_args, **func_kwargs)
......
...@@ -138,6 +138,10 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -138,6 +138,10 @@ def rewrap(decorator_func, previous_target, new_target):
decorator_func: Callable returned by `wrap`. decorator_func: Callable returned by `wrap`.
previous_target: Callable that needs to be replaced. previous_target: Callable that needs to be replaced.
new_target: Callable to replace previous_target with. new_target: Callable to replace previous_target with.
Returns:
The updated decorator. If decorator_func is not a tf_decorator, new_target
is returned.
""" """
# Because the process mutates the decorator, we only need to alter the # Because the process mutates the decorator, we only need to alter the
# innermost function that wraps previous_target. # innermost function that wraps previous_target.
...@@ -145,14 +149,20 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -145,14 +149,20 @@ def rewrap(decorator_func, previous_target, new_target):
innermost_decorator = None innermost_decorator = None
target = None target = None
while hasattr(cur, '_tf_decorator'): while hasattr(cur, '_tf_decorator'):
assert cur is not None
innermost_decorator = cur innermost_decorator = cur
target = getattr(cur, '_tf_decorator') target = getattr(cur, '_tf_decorator')
if target.decorated_target is previous_target: if target.decorated_target is previous_target:
break break
cur = target.decorated_target cur = target.decorated_target
# If decorator_func is not a decorator, new_target replaces it directly.
if innermost_decorator is None: if innermost_decorator is None:
return # Consistency check. The caller should always pass the result of
# tf_decorator.unwrap as previous_target. If decorator_func is not a
# decorator, that will have returned decorator_func itself.
assert decorator_func is previous_target
return new_target
target.decorated_target = new_target target.decorated_target = new_target
...@@ -168,6 +178,8 @@ def rewrap(decorator_func, previous_target, new_target): ...@@ -168,6 +178,8 @@ def rewrap(decorator_func, previous_target, new_target):
else: else:
innermost_decorator.__wrapped__ = new_target innermost_decorator.__wrapped__ = new_target
return decorator_func
def unwrap(maybe_tf_decorator): def unwrap(maybe_tf_decorator):
"""Unwraps an object into a list of TFDecorators and a final target. """Unwraps an object into a list of TFDecorators and a final target.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册