From 350791003de42dbb17c53474a677b108f473b0ba Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 12 Dec 2018 11:52:20 -0800 Subject: [PATCH] Reduce the cost of serializing ConversionOptions to code, by using a more efficient inspect.util.getqualifiedname, reducing its max_depth and falling back to caching the value in the namespace. The latter step makes it more difficult to run the generated code afterwards, but it should in turn speed up the conversion process. This also adds an extra check to tf_decorator to improve robustness. PiperOrigin-RevId: 225226256 --- .../python/autograph/converters/call_trees.py | 2 +- tensorflow/python/autograph/core/converter.py | 18 ++++-- tensorflow/python/autograph/impl/api.py | 3 + .../python/autograph/pyct/inspect_utils.py | 25 +++++--- .../autograph/pyct/inspect_utils_test.py | 57 +++++++++++++++++++ tensorflow/python/util/tf_decorator.py | 5 ++ 6 files changed, 94 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 3e0b40290f3..b1bfe04347e 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -261,7 +261,7 @@ class CallTreeTransformer(converter.Base): func=func, owner=owner, options=self.ctx.program.options.to_ast( - self.ctx.info.namespace, + self.ctx, internal_convert_user_code=self.ctx.program.options.recursive), args=node.args) # TODO(mdan): Improve the template mechanism to better support this. diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index eea2621056b..b9c2449566e 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -179,15 +179,14 @@ class ConversionOptions(object): return (Feature.ALL in self.optional_features or feature in self.optional_features) - def to_ast(self, namespace, internal_convert_user_code=None): + def to_ast(self, ctx, internal_convert_user_code=None): """Returns a representation of this object as an AST node. The AST node encodes a constructor that would create an object with the same contents. Args: - namespace: Dict[str, Any], the namespace to use when serializing values to - names. + ctx: EntityContext, the entity with which this AST needs to be consistent. internal_convert_user_code: Optional[bool], allows ovrriding the corresponding value. @@ -205,10 +204,11 @@ class ConversionOptions(object): """ def as_qualified_name(o): - name = inspect_utils.getqualifiedname(namespace, o) + name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1) if not name: - raise ValueError('Could not locate entity {} in {}'.format( - o, namespace)) + # TODO(mdan): This needs to account for the symbols defined locally. + name = ctx.namer.new_symbol(o.__name__, ()) + ctx.program.add_symbol(name, o) return name def list_of_names(values): @@ -279,6 +279,7 @@ class ProgramContext(object): self.dependency_cache = {} self.additional_imports = set() self.name_map = {} + self.additional_symbols = {} @property def required_imports(self): @@ -321,6 +322,11 @@ class ProgramContext(object): else: self.name_map[o] = name + def add_symbol(self, name, value): + if name in self.additional_symbols: + assert self.additional_symbols[name] is value + self.additional_symbols[name] = value + def add_to_cache(self, original_entity, converted_ast): self.conversion_order.append(original_entity) self.dependency_cache[original_entity] = converted_ast diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 54b46b1efdb..a20ad71c97c 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -424,6 +424,9 @@ def to_graph(entity, # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val + for key, val in program_ctx.additional_symbols.items(): + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val compiled = getattr(compiled_module, name) if tf_inspect.isfunction(entity): diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 7c819f364fa..56945b464b9 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -101,7 +101,7 @@ def getnamespace(f): return namespace -def getqualifiedname(namespace, object_, max_depth=2): +def getqualifiedname(namespace, object_, max_depth=7, visited=None): """Returns the name by which a value can be referred to in a given namespace. If the object defines a parent module, the function attempts to use it to @@ -115,16 +115,20 @@ def getqualifiedname(namespace, object_, max_depth=2): object_: Any, the value to search. max_depth: Optional[int], a limit to the recursion depth when searching inside modules. + visited: Optional[Set[int]], ID of modules to avoid visiting. Returns: Union[str, None], the fully-qualified name that resolves to the value o, or None if it couldn't be found. """ - for name, value in namespace.items(): + if visited is None: + visited = set() + + for name in namespace: # The value may be referenced by more than one symbol, case in which # any symbol will be fine. If the program contains symbol aliases that # change over time, this may capture a symbol that will later point to # something else. # TODO(mdan): Prefer the symbol that matches the value type name. - if object_ is value: + if object_ is namespace[name]: return name # If an object is not found, try to search its parent modules. @@ -132,22 +136,25 @@ def getqualifiedname(namespace, object_, max_depth=2): if (parent is not None and parent is not object_ and parent is not namespace): # No limit to recursion depth because of the guard above. - parent_name = getqualifiedname(namespace, parent, max_depth=0) + parent_name = getqualifiedname( + namespace, parent, max_depth=0, visited=visited) if parent_name is not None: - name_in_parent = getqualifiedname(parent.__dict__, object_, max_depth=0) + name_in_parent = getqualifiedname( + parent.__dict__, object_, max_depth=0, visited=visited) assert name_in_parent is not None, ( 'An object should always be found in its owner module') return '{}.{}'.format(parent_name, name_in_parent) - # TODO(mdan): Use breadth-first search and avoid visiting modules twice. if max_depth: # Iterating over a copy prevents "changed size due to iteration" errors. # It's unclear why those occur - suspecting new modules may load during # iteration. - for name, value in namespace.copy().items(): - if tf_inspect.ismodule(value): + for name in tuple(namespace.keys()): + value = namespace[name] + if tf_inspect.ismodule(value) and id(value) not in visited: + visited.add(id(value)) name_in_module = getqualifiedname(value.__dict__, object_, - max_depth - 1) + max_depth - 1, visited) if name_in_module is not None: return '{}.{}'.format(name, name_in_module) return None diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index a2c39056d1b..420a20c22f2 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -183,6 +183,63 @@ class InspectUtilsTest(test.TestCase): self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar') self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz') + def test_getqualifiedname_efficiency(self): + foo = object() + + # We create a densely connected graph consisting of a relatively small + # number of modules and hide our symbol in one of them. The path to the + # symbol is at least 10, and each node has about 10 neighbors. However, + # by skipping visited modules, the search should take much less. + ns = {} + prev_level = [] + for i in range(10): + current_level = [] + for j in range(10): + mod_name = 'mod_{}_{}'.format(i, j) + mod = imp.new_module(mod_name) + current_level.append(mod) + if i == 9 and j == 9: + mod.foo = foo + if prev_level: + # All modules at level i refer to all modules at level i+1 + for prev in prev_level: + for mod in current_level: + prev.__dict__[mod.__name__] = mod + else: + for mod in current_level: + ns[mod.__name__] = mod + prev_level = current_level + + self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils)) + self.assertIsNotNone( + inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000)) + + def test_getqualifiedname_cycles(self): + foo = object() + + # We create a graph of modules that contains circular references. The + # search process should avoid them. The searched object is hidden at the + # bottom of a path of length roughly 10. + ns = {} + mods = [] + for i in range(10): + mod = imp.new_module('mod_{}'.format(i)) + if i == 9: + mod.foo = foo + # Module i refers to module i+1 + if mods: + mods[-1].__dict__[mod.__name__] = mod + else: + ns[mod.__name__] = mod + # Module i refers to all modules j < i. + for prev in mods: + mod.__dict__[prev.__name__] = prev + mods.append(mod) + + self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils)) + self.assertIsNotNone( + inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000)) + def test_getqualifiedname_finds_via_parent_module(self): # TODO(mdan): This test is vulnerable to change in the lib module. # A better way to forge modules should be found. diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py index 0cfc836246d..f018e1a1bd3 100644 --- a/tensorflow/python/util/tf_decorator.py +++ b/tensorflow/python/util/tf_decorator.py @@ -98,6 +98,9 @@ def make_decorator(target, if hasattr(target, '__doc__'): decorator_func.__doc__ = decorator.__doc__ decorator_func.__wrapped__ = target + # Keeping a second handle to `target` allows callers to detect whether the + # decorator was modified using `rewrap`. + decorator_func.__original_wrapped__ = target return decorator_func @@ -173,6 +176,8 @@ def unwrap(maybe_tf_decorator): decorators.append(getattr(cur, '_tf_decorator')) else: break + if not hasattr(decorators[-1], 'decorated_target'): + break cur = decorators[-1].decorated_target return decorators, cur -- GitLab