提交 de67c554 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Use weak refs when caching symbols in the namespace, to avoid tripping...

Use weak refs when caching symbols in the namespace, to avoid tripping circular reference detectors in tests.

PiperOrigin-RevId: 225376019
上级 04f51da5
......@@ -183,7 +183,7 @@ class CallTreeTransformer(converter.Base):
for dec in target_node.decorator_list:
decorator_fn = self._resolve_decorator_name(dec)
if (decorator_fn is not None and
decorator_fn in self.ctx.program.options.strip_decorators):
self.ctx.program.options.should_strip(decorator_fn)):
return False
return True
......
......@@ -63,6 +63,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import weakref
import enum
from tensorflow.python.autograph.core import config
......@@ -175,6 +177,16 @@ class ConversionOptions(object):
# TODO(mdan): Revert if function.defun becomes a public symbol.
return self._strip_decorators + (function.defun,)
def should_strip(self, decorator):
for blacklisted in self.strip_decorators:
if blacklisted is decorator:
return True
if isinstance(blacklisted, weakref.ref):
blacklisted_deref = blacklisted()
if (blacklisted_deref is not None and blacklisted_deref is decorator):
return True
return False
def uses(self, feature):
return (Feature.ALL in self.optional_features or
feature in self.optional_features)
......@@ -208,7 +220,7 @@ class ConversionOptions(object):
if not name:
# TODO(mdan): This needs to account for the symbols defined locally.
name = ctx.namer.new_symbol(o.__name__, ())
ctx.program.add_symbol(name, o)
ctx.program.add_symbol(name, weakref.ref(o))
return name
def list_of_names(values):
......
......@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import weakref
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
......@@ -29,6 +31,36 @@ class TestConverter(converter.Base):
pass
class ConversionOptionsTest(test.TestCase):
def test_should_strip_weakrefs(self):
def test_fn():
pass
def weak_test_fn_a():
pass
def weak_test_fn_b():
pass
def weak_test_fn_c():
pass
wr_a = weakref.ref(weak_test_fn_a)
# Create an extra weakref to check whether the existence of multiple weak
# references influences the process.
_ = weakref.ref(weak_test_fn_b)
wr_b = weakref.ref(weak_test_fn_b)
_ = weakref.ref(weak_test_fn_c)
opts = converter.ConversionOptions(strip_decorators=(test_fn, wr_a, wr_b))
self.assertTrue(opts.should_strip(test_fn))
self.assertTrue(opts.should_strip(weak_test_fn_a))
self.assertTrue(opts.should_strip(weak_test_fn_b))
self.assertFalse(opts.should_strip(weak_test_fn_c))
class ConverterBaseTest(converter_testing.TestCase):
def test_get_definition_directive_basic(self):
......
......@@ -101,7 +101,7 @@ def getnamespace(f):
return namespace
def getqualifiedname(namespace, object_, max_depth=7, visited=None):
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
"""Returns the name by which a value can be referred to in a given namespace.
If the object defines a parent module, the function attempts to use it to
......@@ -149,7 +149,7 @@ def getqualifiedname(namespace, object_, max_depth=7, visited=None):
# Iterating over a copy prevents "changed size due to iteration" errors.
# It's unclear why those occur - suspecting new modules may load during
# iteration.
for name in tuple(namespace.keys()):
for name in namespace.keys():
value = namespace[name]
if tf_inspect.ismodule(value) and id(value) not in visited:
visited.add(id(value))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册