提交 35079100 编写于 作者: D Dan Moldovan 提交者: TensorFlower Gardener

Reduce the cost of serializing ConversionOptions to code, by using a more...

Reduce the cost of serializing ConversionOptions to code, by using a more efficient inspect_utils.getqualifiedname, reducing its max_depth and falling back to caching the value in the namespace. The latter step makes it more difficult to run the generated code afterwards, but it should in turn speed up the conversion process. This also adds an extra check to tf_decorator to improve robustness.

PiperOrigin-RevId: 225226256
上级 3ae0654d
......@@ -261,7 +261,7 @@ class CallTreeTransformer(converter.Base):
func=func,
owner=owner,
options=self.ctx.program.options.to_ast(
self.ctx.info.namespace,
self.ctx,
internal_convert_user_code=self.ctx.program.options.recursive),
args=node.args)
# TODO(mdan): Improve the template mechanism to better support this.
......
......@@ -179,15 +179,14 @@ class ConversionOptions(object):
return (Feature.ALL in self.optional_features or
feature in self.optional_features)
def to_ast(self, namespace, internal_convert_user_code=None):
def to_ast(self, ctx, internal_convert_user_code=None):
"""Returns a representation of this object as an AST node.
The AST node encodes a constructor that would create an object with the
same contents.
Args:
namespace: Dict[str, Any], the namespace to use when serializing values to
names.
ctx: EntityContext, the entity with which this AST needs to be consistent.
internal_convert_user_code: Optional[bool], allows overriding the
corresponding value.
......@@ -205,10 +204,11 @@ class ConversionOptions(object):
"""
def as_qualified_name(o):
name = inspect_utils.getqualifiedname(namespace, o)
name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
if not name:
raise ValueError('Could not locate entity {} in {}'.format(
o, namespace))
# TODO(mdan): This needs to account for the symbols defined locally.
name = ctx.namer.new_symbol(o.__name__, ())
ctx.program.add_symbol(name, o)
return name
def list_of_names(values):
......@@ -279,6 +279,7 @@ class ProgramContext(object):
self.dependency_cache = {}
self.additional_imports = set()
self.name_map = {}
self.additional_symbols = {}
@property
def required_imports(self):
......@@ -321,6 +322,11 @@ class ProgramContext(object):
else:
self.name_map[o] = name
def add_symbol(self, name, value):
  """Registers `value` under `name` in `self.additional_symbols`.

  Registering the same object again under the same name leaves the mapping
  unchanged.

  Args:
    name: the key under which the value is stored.
    value: the object to store.

  Raises:
    AssertionError: if `name` is already bound to a different object. The
      existing binding is left untouched in that case.
  """
  if name in self.additional_symbols:
    existing = self.additional_symbols[name]
    # Raised explicitly instead of through an `assert` statement so that the
    # check is not stripped under `python -O`, which would let a conflicting
    # registration silently overwrite the earlier one. The exception type is
    # the same one the `assert` produced.
    if existing is not value:
      raise AssertionError(
          'Symbol "{}" is already bound to a different object: {!r}'.format(
              name, existing))
    return
  self.additional_symbols[name] = value
def add_to_cache(self, original_entity, converted_ast):
self.conversion_order.append(original_entity)
self.dependency_cache[original_entity] = converted_ast
......
......@@ -424,6 +424,9 @@ def to_graph(entity,
# Avoid overwriting entities that have been transformed.
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
for key, val in program_ctx.additional_symbols.items():
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
compiled = getattr(compiled_module, name)
if tf_inspect.isfunction(entity):
......
......@@ -101,7 +101,7 @@ def getnamespace(f):
return namespace
def getqualifiedname(namespace, object_, max_depth=2):
def getqualifiedname(namespace, object_, max_depth=7, visited=None):
"""Returns the name by which a value can be referred to in a given namespace.
If the object defines a parent module, the function attempts to use it to
......@@ -115,16 +115,20 @@ def getqualifiedname(namespace, object_, max_depth=2):
object_: Any, the value to search.
max_depth: Optional[int], a limit to the recursion depth when searching
inside modules.
visited: Optional[Set[int]], ID of modules to avoid visiting.
Returns: Union[str, None], the fully-qualified name that resolves to the value
o, or None if it couldn't be found.
"""
for name, value in namespace.items():
if visited is None:
visited = set()
for name in namespace:
# The value may be referenced by more than one symbol, in which case
# any symbol will be fine. If the program contains symbol aliases that
# change over time, this may capture a symbol that will later point to
# something else.
# TODO(mdan): Prefer the symbol that matches the value type name.
if object_ is value:
if object_ is namespace[name]:
return name
# If an object is not found, try to search its parent modules.
......@@ -132,22 +136,25 @@ def getqualifiedname(namespace, object_, max_depth=2):
if (parent is not None and parent is not object_ and
parent is not namespace):
# No limit to recursion depth because of the guard above.
parent_name = getqualifiedname(namespace, parent, max_depth=0)
parent_name = getqualifiedname(
namespace, parent, max_depth=0, visited=visited)
if parent_name is not None:
name_in_parent = getqualifiedname(parent.__dict__, object_, max_depth=0)
name_in_parent = getqualifiedname(
parent.__dict__, object_, max_depth=0, visited=visited)
assert name_in_parent is not None, (
'An object should always be found in its owner module')
return '{}.{}'.format(parent_name, name_in_parent)
# TODO(mdan): Use breadth-first search and avoid visiting modules twice.
if max_depth:
# Iterating over a copy prevents "changed size due to iteration" errors.
# It's unclear why those occur - suspecting new modules may load during
# iteration.
for name, value in namespace.copy().items():
if tf_inspect.ismodule(value):
for name in tuple(namespace.keys()):
value = namespace[name]
if tf_inspect.ismodule(value) and id(value) not in visited:
visited.add(id(value))
name_in_module = getqualifiedname(value.__dict__, object_,
max_depth - 1)
max_depth - 1, visited)
if name_in_module is not None:
return '{}.{}'.format(name, name_in_module)
return None
......
......@@ -183,6 +183,63 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
def test_getqualifiedname_efficiency(self):
  """Searches a densely connected module graph for a hidden symbol."""
  # Imported locally: `types.ModuleType` replaces `imp.new_module`, because
  # the `imp` module is deprecated and was removed in Python 3.12.
  import types  # pylint: disable=g-import-not-at-top

  foo = object()

  # We create a densely connected graph consisting of a relatively small
  # number of modules and hide our symbol in one of them. The path to the
  # symbol is at least 10, and each node has about 10 neighbors. However,
  # by skipping visited modules, the search should take much less.
  ns = {}
  prev_level = []
  for i in range(10):
    current_level = []
    for j in range(10):
      mod = types.ModuleType('mod_{}_{}'.format(i, j))
      current_level.append(mod)
      # Only the very last module created holds the symbol.
      if i == 9 and j == 9:
        mod.foo = foo
    if prev_level:
      # All modules at level i - 1 refer to all modules at level i.
      for prev in prev_level:
        for mod in current_level:
          prev.__dict__[mod.__name__] = mod
    else:
      # The first level is reachable directly from the namespace.
      for mod in current_level:
        ns[mod.__name__] = mod
    prev_level = current_level

  # `inspect_utils` is not stored anywhere in the graph built above, so this
  # lookup is expected to come back empty.
  self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
  self.assertIsNotNone(
      inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))
def test_getqualifiedname_cycles(self):
  """Searches a module graph that contains circular references."""
  # Imported locally: `types.ModuleType` replaces `imp.new_module`, because
  # the `imp` module is deprecated and was removed in Python 3.12.
  import types  # pylint: disable=g-import-not-at-top

  foo = object()

  # We create a graph of modules that contains circular references. The
  # search process should avoid them. The searched object is hidden at the
  # bottom of a path of length roughly 10.
  ns = {}
  mods = []
  for i in range(10):
    mod = types.ModuleType('mod_{}'.format(i))
    if i == 9:
      mod.foo = foo
    # Module i - 1 refers to module i; the first module is reachable directly
    # from the namespace.
    if mods:
      mods[-1].__dict__[mod.__name__] = mod
    else:
      ns[mod.__name__] = mod
    # Module i refers to all modules j < i, which creates the cycles.
    for prev in mods:
      mod.__dict__[prev.__name__] = prev
    mods.append(mod)

  # `inspect_utils` is not stored anywhere in the graph built above, so this
  # lookup is expected to come back empty.
  self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
  self.assertIsNotNone(
      inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))
def test_getqualifiedname_finds_via_parent_module(self):
# TODO(mdan): This test is vulnerable to change in the lib module.
# A better way to forge modules should be found.
......
......@@ -98,6 +98,9 @@ def make_decorator(target,
if hasattr(target, '__doc__'):
decorator_func.__doc__ = decorator.__doc__
decorator_func.__wrapped__ = target
# Keeping a second handle to `target` allows callers to detect whether the
# decorator was modified using `rewrap`.
decorator_func.__original_wrapped__ = target
return decorator_func
......@@ -173,6 +176,8 @@ def unwrap(maybe_tf_decorator):
decorators.append(getattr(cur, '_tf_decorator'))
else:
break
if not hasattr(decorators[-1], 'decorated_target'):
break
cur = decorators[-1].decorated_target
return decorators, cur
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册