From de67c554a1496616e007548ca862ea74d7fc6ae3 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Thu, 13 Dec 2018 08:32:38 -0800 Subject: [PATCH] Use weak refs when caching symbols in the namespace, to avoid tripping circular reference detectors in tests. PiperOrigin-RevId: 225376019 --- .../python/autograph/converters/call_trees.py | 2 +- tensorflow/python/autograph/core/converter.py | 14 +++++++- .../python/autograph/core/converter_test.py | 32 +++++++++++++++++++ .../python/autograph/pyct/inspect_utils.py | 4 +-- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index b1bfe04347e..d4eb17e976f 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -183,7 +183,7 @@ class CallTreeTransformer(converter.Base): for dec in target_node.decorator_list: decorator_fn = self._resolve_decorator_name(dec) if (decorator_fn is not None and - decorator_fn in self.ctx.program.options.strip_decorators): + self.ctx.program.options.should_strip(decorator_fn)): return False return True diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index b9c2449566e..4543b113983 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -63,6 +63,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import weakref + import enum from tensorflow.python.autograph.core import config @@ -175,6 +177,16 @@ class ConversionOptions(object): # TODO(mdan): Revert if function.defun becomes a public symbol. return self._strip_decorators + (function.defun,) + def should_strip(self, decorator): + for blacklisted in self.strip_decorators: + if blacklisted is decorator: + return True + if isinstance(blacklisted, weakref.ref): + blacklisted_deref = blacklisted() + if (blacklisted_deref is not None and blacklisted_deref is decorator): + return True + return False + def uses(self, feature): return (Feature.ALL in self.optional_features or feature in self.optional_features) @@ -208,7 +220,7 @@ class ConversionOptions(object): if not name: # TODO(mdan): This needs to account for the symbols defined locally. name = ctx.namer.new_symbol(o.__name__, ()) - ctx.program.add_symbol(name, o) + ctx.program.add_symbol(name, weakref.ref(o)) return name def list_of_names(values): diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py index b73c67e3377..864ea6c7d2b 100644 --- a/tensorflow/python/autograph/core/converter_test.py +++ b/tensorflow/python/autograph/core/converter_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import weakref + from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.pyct import anno @@ -29,6 +31,36 @@ class TestConverter(converter.Base): pass +class ConversionOptionsTest(test.TestCase): + + def test_should_strip_weakrefs(self): + def test_fn(): + pass + + def weak_test_fn_a(): + pass + + def weak_test_fn_b(): + pass + + def weak_test_fn_c(): + pass + + wr_a = weakref.ref(weak_test_fn_a) + # Create an extra weakref to check whether the existence of multiple weak + # references influences the process. + _ = weakref.ref(weak_test_fn_b) + wr_b = weakref.ref(weak_test_fn_b) + _ = weakref.ref(weak_test_fn_c) + + opts = converter.ConversionOptions(strip_decorators=(test_fn, wr_a, wr_b)) + + self.assertTrue(opts.should_strip(test_fn)) + self.assertTrue(opts.should_strip(weak_test_fn_a)) + self.assertTrue(opts.should_strip(weak_test_fn_b)) + self.assertFalse(opts.should_strip(weak_test_fn_c)) + + class ConverterBaseTest(converter_testing.TestCase): def test_get_definition_directive_basic(self): diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 56945b464b9..360dd83b5e3 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -101,7 +101,7 @@ def getnamespace(f): return namespace -def getqualifiedname(namespace, object_, max_depth=7, visited=None): +def getqualifiedname(namespace, object_, max_depth=5, visited=None): """Returns the name by which a value can be referred to in a given namespace. If the object defines a parent module, the function attempts to use it to @@ -149,7 +149,7 @@ def getqualifiedname(namespace, object_, max_depth=7, visited=None): # Iterating over a copy prevents "changed size due to iteration" errors. # It's unclear why those occur - suspecting new modules may load during # iteration. - for name in tuple(namespace.keys()): + for name in namespace.keys(): value = namespace[name] if tf_inspect.ismodule(value) and id(value) not in visited: visited.add(id(value)) -- GitLab