提交 9a6a3793 编写于 作者: M Megvii Engine Team

feat(traced_module): add visit method

GitOrigin-RevId: 251ecebf87c94fd5b60c27596a45149d479603e9
上级 442b4f6c
......@@ -10,7 +10,7 @@ import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type
from typing import Dict, List, Type
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
......@@ -64,6 +64,14 @@ class InternalGraph:
def insert(self, expr):
self._exprs.append(expr)
@property
def inputs(self):
return self._inputs
@property
def outputs(self):
return self._outputs
def add_input(self, i):
self._inputs.append(i)
......@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
return wrapped
class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module
def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
yield expr
class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
......@@ -291,14 +315,21 @@ class TracedModule(Module):
return rst
@property
def all_exprs(self):
def exprs(self):
"""
Get all ``Expr`` s recursively.
:return: Iterator[Expr]
"""
Visit all ``Expr``s in the graph recursively.
return _expr_list(self)
:return: List[Expr]
def flatten(self):
"""
Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]
:return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)
def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
......@@ -328,6 +359,7 @@ class TracedModule(Module):
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
else:
......@@ -337,7 +369,9 @@ class TracedModule(Module):
return exprs
return in_nodes + _flatten_submodule(self)
new_module.m_node.graph._exprs = _flatten_submodule(new_module)
return new_module
def __getstate__(self):
d = self.__dict__
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册