提交 9a6a3793 编写于 作者: M Megvii Engine Team

feat(traced_module): add visit method

GitOrigin-RevId: 251ecebf87c94fd5b60c27596a45149d479603e9
上级 442b4f6c
...@@ -10,7 +10,7 @@ import collections ...@@ -10,7 +10,7 @@ import collections
import copy import copy
import functools import functools
from inspect import getmembers, isclass, ismethod from inspect import getmembers, isclass, ismethod
from typing import List, Type from typing import Dict, List, Type
from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
...@@ -64,6 +64,14 @@ class InternalGraph: ...@@ -64,6 +64,14 @@ class InternalGraph:
def insert(self, expr): def insert(self, expr):
self._exprs.append(expr) self._exprs.append(expr)
@property
def inputs(self):
return self._inputs
@property
def outputs(self):
return self._outputs
def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)
...@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin): ...@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
return wrapped return wrapped
class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module
def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
yield expr
class TracedModule(Module): class TracedModule(Module):
""" """
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
...@@ -291,14 +315,21 @@ class TracedModule(Module): ...@@ -291,14 +315,21 @@ class TracedModule(Module):
return rst return rst
@property @property
def all_exprs(self): def exprs(self):
"""
Get all ``Expr`` s recursively.
:return: Iterator[Expr]
""" """
Visit all ``Expr``s in the graph recursively. return _expr_list(self)
:return: List[Expr] def flatten(self):
""" """
Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] :return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)
def _flatten_submodule(module, call=None): def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule): if not isinstance(module, TracedModule):
...@@ -328,6 +359,7 @@ class TracedModule(Module): ...@@ -328,6 +359,7 @@ class TracedModule(Module):
elif isinstance(expr, CallMethod): elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0] obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode): if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module) (obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr)) exprs.extend(_flatten_submodule(obj, expr))
else: else:
...@@ -337,7 +369,9 @@ class TracedModule(Module): ...@@ -337,7 +369,9 @@ class TracedModule(Module):
return exprs return exprs
return in_nodes + _flatten_submodule(self) new_module.m_node.graph._exprs = _flatten_submodule(new_module)
return new_module
def __getstate__(self): def __getstate__(self):
d = self.__dict__ d = self.__dict__
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册