提交 827c1005 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Cache information about unconverted functions for faster subsequent execution....

Cache information about unconverted functions for faster subsequent execution. This resulted in about 3x speedup in a basic benchmark operating on NumPy code, although it remained 3-4 slower compared to pure Python code.
It will also avoid log spam when conversion fails in most cases.

PiperOrigin-RevId: 258843781
上级 20dad0ba
......@@ -340,8 +340,11 @@ def _attach_metadata(e, f, converted):
e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map)
def _call_unconverted(f, args, kwargs):
def _call_unconverted(f, args, kwargs, options, update_cache=True):
"""Calls the original function without converting with AutoGraph."""
if update_cache:
conversion.cache_unconverted(f, options)
if inspect_utils.istfmethodtarget(f):
return f.__self__.call(args, kwargs)
......@@ -387,6 +390,9 @@ def converted_call(f, options, args, kwargs):
logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
kwargs)
if conversion.check_cached_unconverted(f, options):
return _call_unconverted(f, args, kwargs, options, False)
if inspect_utils.isbuiltin(f):
if f is eval:
return py_builtins.eval_in_original_context(f, args, 1)
......@@ -398,7 +404,7 @@ def converted_call(f, options, args, kwargs):
# TODO(mdan): Clean up the naming inconsistency.
if hasattr(f, 'autograph_info__') or hasattr(f, '__ag_compiled'):
logging.log(2, 'Permanently whitelisted: %s: already converted', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# TODO(b/122265385): Remove this bypass.
if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
......@@ -408,42 +414,42 @@ def converted_call(f, options, args, kwargs):
' by AutoGraph. The function will be called without transformation.'
' You may however apply AutoGraph before the decorator.'.format(f))
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# Constructors are permanently whitelisted.
# TODO(mdan): Toggle as experimental feature instead.
# TODO(b/124016764): Remove this limitation.
if tf_inspect.isclass(f):
logging.log(2, 'Permanently whitelisted: %s: constructor', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if any(
f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# Custom ops and kernels are also permanently whitelisted.
# See tensorflow.framework.load_library.
if (hasattr(f, '__module__') and
hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# internal_convert_user_code is for example turned off when issuing a dynamic
# call conversion from generated code while in nonrecursive mode. In that
# case we evidently don't want to recurse, but we still have to convert
# things like builtins.
if not options.internal_convert_user_code:
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
# TODO(mdan): Move this entire block inside to_graph.
try: # Begin of transformation error guards
......@@ -493,13 +499,13 @@ def converted_call(f, options, args, kwargs):
if not hasattr(target_entity, '__code__'):
logging.log(2, 'Permanently whitelisted: %s: native binding',
target_entity)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
elif (hasattr(target_entity.__code__, 'co_filename') and
target_entity.__code__.co_filename == '<string>'):
# TODO(mdan): __globals__['txt'] might work in Py3.
logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
target_entity)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
converted_f = to_graph(
target_entity,
......@@ -532,7 +538,7 @@ def converted_call(f, options, args, kwargs):
' Please report this to the AutoGraph team. When filing the bug, set'
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
' attach the full output. Cause: %s', target_entity, e)
return _call_unconverted(f, args, kwargs)
return _call_unconverted(f, args, kwargs, options)
with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
try:
......
......@@ -129,6 +129,7 @@ _CACHE_LOCK = threading.RLock()
_CACHE = _ConversionCache()
_UNCONVERTED_CACHE = _ConversionCache()
# Note: strictly speaking, a simple factory might have been sufficient for
......@@ -414,6 +415,22 @@ def is_whitelisted_for_graph(o, check_call_override=True):
return False
def check_cached_unconverted(entity, options):
try:
# Catch-all for entities that are unhashable or don't allow weakrefs.
return _UNCONVERTED_CACHE.has(entity, options)
except TypeError:
return False
def cache_unconverted(entity, options):
try:
# Catch-all for entities that are unhashable or don't allow weakrefs.
_UNCONVERTED_CACHE[entity][options] = True
except TypeError:
pass
# TODO(mdan): Rename to convert_*_node to avoid confusion with convert.
def convert_entity_to_ast(o, program_ctx):
"""Compile a Python entity into equivalent TensorFlow.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册