diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 3e0b40290f37646c09ffd4058b2cf20e160660bc..b1bfe04347e9f1ce6a11eeee435aa6a48995f27a 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -261,7 +261,7 @@ class CallTreeTransformer(converter.Base): func=func, owner=owner, options=self.ctx.program.options.to_ast( - self.ctx.info.namespace, + self.ctx, internal_convert_user_code=self.ctx.program.options.recursive), args=node.args) # TODO(mdan): Improve the template mechanism to better support this. diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index eea2621056b36f97a77fa2035c913f3b56abcf4b..b9c2449566ec8b7b4c2f1b2f1fdd919f0d18bf32 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -179,15 +179,14 @@ class ConversionOptions(object): return (Feature.ALL in self.optional_features or feature in self.optional_features) - def to_ast(self, namespace, internal_convert_user_code=None): + def to_ast(self, ctx, internal_convert_user_code=None): """Returns a representation of this object as an AST node. The AST node encodes a constructor that would create an object with the same contents. Args: - namespace: Dict[str, Any], the namespace to use when serializing values to - names. + ctx: EntityContext, the entity with which this AST needs to be consistent. internal_convert_user_code: Optional[bool], allows ovrriding the corresponding value. @@ -205,10 +204,11 @@ class ConversionOptions(object): """ def as_qualified_name(o): - name = inspect_utils.getqualifiedname(namespace, o) + name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1) if not name: - raise ValueError('Could not locate entity {} in {}'.format( - o, namespace)) + # TODO(mdan): This needs to account for the symbols defined locally. + name = ctx.namer.new_symbol(o.__name__, ()) + ctx.program.add_symbol(name, o) return name def list_of_names(values): @@ -279,6 +279,7 @@ class ProgramContext(object): self.dependency_cache = {} self.additional_imports = set() self.name_map = {} + self.additional_symbols = {} @property def required_imports(self): @@ -321,6 +322,11 @@ class ProgramContext(object): else: self.name_map[o] = name + def add_symbol(self, name, value): + if name in self.additional_symbols: + assert self.additional_symbols[name] is value + self.additional_symbols[name] = value + def add_to_cache(self, original_entity, converted_ast): self.conversion_order.append(original_entity) self.dependency_cache[original_entity] = converted_ast diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 54b46b1efdbc14081b5ef00cc9a179c99b5439ec..a20ad71c97cb37df772712c3e00c3fdc16b6fb8d 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -424,6 +424,9 @@ def to_graph(entity, # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val + for key, val in program_ctx.additional_symbols.items(): + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val compiled = getattr(compiled_module, name) if tf_inspect.isfunction(entity): diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 7c819f364fa79d40c0fbb080b3b358b36bfd8c0c..56945b464b9f89cf8f56f02bf65e37e634af8f90 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -101,7 +101,7 @@ def getnamespace(f): return namespace -def getqualifiedname(namespace, object_, max_depth=2): +def getqualifiedname(namespace, object_, max_depth=7, visited=None): """Returns the name by which a value can be referred to in a given namespace. If the object defines a parent module, the function attempts to use it to @@ -115,16 +115,20 @@ def getqualifiedname(namespace, object_, max_depth=2): object_: Any, the value to search. max_depth: Optional[int], a limit to the recursion depth when searching inside modules. + visited: Optional[Set[int]], ID of modules to avoid visiting. Returns: Union[str, None], the fully-qualified name that resolves to the value o, or None if it couldn't be found. """ - for name, value in namespace.items(): + if visited is None: + visited = set() + + for name in namespace: # The value may be referenced by more than one symbol, case in which # any symbol will be fine. If the program contains symbol aliases that # change over time, this may capture a symbol that will later point to # something else. # TODO(mdan): Prefer the symbol that matches the value type name. - if object_ is value: + if object_ is namespace[name]: return name # If an object is not found, try to search its parent modules. @@ -132,22 +136,25 @@ def getqualifiedname(namespace, object_, max_depth=2): if (parent is not None and parent is not object_ and parent is not namespace): # No limit to recursion depth because of the guard above. - parent_name = getqualifiedname(namespace, parent, max_depth=0) + parent_name = getqualifiedname( + namespace, parent, max_depth=0, visited=visited) if parent_name is not None: - name_in_parent = getqualifiedname(parent.__dict__, object_, max_depth=0) + name_in_parent = getqualifiedname( + parent.__dict__, object_, max_depth=0, visited=visited) assert name_in_parent is not None, ( 'An object should always be found in its owner module') return '{}.{}'.format(parent_name, name_in_parent) - # TODO(mdan): Use breadth-first search and avoid visiting modules twice. if max_depth: # Iterating over a copy prevents "changed size due to iteration" errors. # It's unclear why those occur - suspecting new modules may load during # iteration. - for name, value in namespace.copy().items(): - if tf_inspect.ismodule(value): + for name in tuple(namespace.keys()): + value = namespace[name] + if tf_inspect.ismodule(value) and id(value) not in visited: + visited.add(id(value)) name_in_module = getqualifiedname(value.__dict__, object_, - max_depth - 1) + max_depth - 1, visited) if name_in_module is not None: return '{}.{}'.format(name, name_in_module) return None diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index a2c39056d1b09dbae937915cf17de5c6f55d4886..420a20c22f22ace67b30d0f79f11014a04e646d1 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -183,6 +183,63 @@ class InspectUtilsTest(test.TestCase): self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar') self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz') + def test_getqualifiedname_efficiency(self): + foo = object() + + # We create a densely connected graph consisting of a relatively small + # number of modules and hide our symbol in one of them. The path to the + # symbol is at least 10, and each node has about 10 neighbors. However, + # by skipping visited modules, the search should take much less. + ns = {} + prev_level = [] + for i in range(10): + current_level = [] + for j in range(10): + mod_name = 'mod_{}_{}'.format(i, j) + mod = imp.new_module(mod_name) + current_level.append(mod) + if i == 9 and j == 9: + mod.foo = foo + if prev_level: + # All modules at level i refer to all modules at level i+1 + for prev in prev_level: + for mod in current_level: + prev.__dict__[mod.__name__] = mod + else: + for mod in current_level: + ns[mod.__name__] = mod + prev_level = current_level + + self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils)) + self.assertIsNotNone( + inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000)) + + def test_getqualifiedname_cycles(self): + foo = object() + + # We create a graph of modules that contains circular references. The + # search process should avoid them. The searched object is hidden at the + # bottom of a path of length roughly 10. + ns = {} + mods = [] + for i in range(10): + mod = imp.new_module('mod_{}'.format(i)) + if i == 9: + mod.foo = foo + # Module i refers to module i+1 + if mods: + mods[-1].__dict__[mod.__name__] = mod + else: + ns[mod.__name__] = mod + # Module i refers to all modules j < i. + for prev in mods: + mod.__dict__[prev.__name__] = prev + mods.append(mod) + + self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils)) + self.assertIsNotNone( + inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000)) + def test_getqualifiedname_finds_via_parent_module(self): # TODO(mdan): This test is vulnerable to change in the lib module. # A better way to forge modules should be found. diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py index 0cfc836246d2d885c28d168fe90b08a325cf6ded..f018e1a1bd35f0111cacc20e678c0466bfd5f2e3 100644 --- a/tensorflow/python/util/tf_decorator.py +++ b/tensorflow/python/util/tf_decorator.py @@ -98,6 +98,9 @@ def make_decorator(target, if hasattr(target, '__doc__'): decorator_func.__doc__ = decorator.__doc__ decorator_func.__wrapped__ = target + # Keeping a second handle to `target` allows callers to detect whether the + # decorator was modified using `rewrap`. + decorator_func.__original_wrapped__ = target return decorator_func @@ -173,6 +176,8 @@ def unwrap(maybe_tf_decorator): decorators.append(getattr(cur, '_tf_decorator')) else: break + if not hasattr(decorators[-1], 'decorated_target'): + break cur = decorators[-1].decorated_target return decorators, cur