提交 4bb25369 编写于 作者: M Megvii Engine Team

feat(traced_module): let CallFunction own graph

GitOrigin-RevId: 66cdbca7e54df07576a984c3fd48d3bcafb678f1
上级 9a6a3793
......@@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef
......@@ -148,6 +148,15 @@ class CallMethod(Expr):
active_module_tracer().current_scope().insert(expr)
return expr
@property
def graph(self):
if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
return None
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
......@@ -252,7 +261,9 @@ class Constant(Expr):
_constant_cache = {}
def __init__(self, c):
# TODO: type check, since not all types should be captured as constant
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
self.value = c
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
......
......@@ -57,9 +57,13 @@ class ModuleNode(Node):
"""
module_type = Module # type: Type[Module]
graph = None
attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
def __repr__(self):
if self._name is None:
......
......@@ -25,7 +25,7 @@ def _dict_flatten(inp):
for key, value in sorted(inp.items()):
results.append(value)
aux_data.append(key)
return results, aux_data
return results, tuple(aux_data)
def _dict_unflatten(inps, aux_data):
......@@ -43,16 +43,23 @@ register_supported_type(
def tree_flatten(
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
values,
leaf_type: Callable = lambda x: type(x),
is_leaf: Callable = lambda _: True,
is_const_leaf: Callable = lambda _: False,
):
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values)
return [values,], LeafDef(leaf_type(values))
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
node.const_val = values
return [values,], node
rst = []
children_defs = []
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
for v in children_values:
v_list, treedef = tree_flatten(v, leaf_type)
v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
rst.extend(v_list)
children_defs.append(treedef)
......@@ -75,6 +82,18 @@ class TreeDef:
start += ch.num_leaves
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
def __hash__(self):
return hash(
tuple(
[
self.type,
self.aux_data,
self.num_leaves,
tuple([hash(x) for x in self.children_defs]),
]
)
)
def __eq__(self, other):
return (
self.type == other.type
......@@ -93,11 +112,20 @@ class LeafDef(TreeDef):
type = (type,)
super().__init__(type, None, [])
self.num_leaves = 1
self.const_val = None
def unflatten(self, leaves):
assert len(leaves) == 1
assert isinstance(leaves[0], self.type), self.type
return leaves[0]
def __eq__(self, other):
return self.type == other.type and self.const_val == other.const_val
def __hash__(self):
return hash(tuple([self.type, self.const_val]))
def __repr__(self):
return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val
)
......@@ -42,6 +42,12 @@ def _leaf_type(node):
return type(node)
def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
return True
class InternalGraph:
"""
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
......@@ -72,6 +78,10 @@ class InternalGraph:
def outputs(self):
return self._outputs
@property
def exprs(self):
return _expr_list(self)
def add_input(self, i):
self._inputs.append(i)
......@@ -111,7 +121,9 @@ def _wrapped_function(orig_func):
def wrapped_fn(*args, **kwargs):
if is_tracing_module():
unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
inputs, tree_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
......@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"_is_traced",
"_arg_def" "build",
"build",
]
def __init__(self, mod):
def __init__(self, mod, is_top_module=False):
super(TracedModuleBuilder, self).__init__()
self._mod = mod
self._body = InternalGraph()
self._is_traced = False
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)
def build(self):
......@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin):
return self._mod
else:
node = NodeMixin.get(self)
node.graph = self._body
node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
......@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)
for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph
def mark_constant(x):
node = NodeMixin.get(x, None)
if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x))
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
inputs, tree_def = tree_flatten(
((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for i in inputs:
mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self))
......@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def
if self._is_builtin or self._is_traced:
if self._is_builtin:
unset_module_tracing()
outputs = self._mod(*args, **kwargs)
set_module_tracing()
if self._is_builtin:
self._body = None
else:
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
......@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.wrap_safe(self, orig_self)
self._is_traced = True
active_module_tracer().pop_scope()
# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
return outputs
def __getattr__(self, name):
......@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin):
class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module
def __init__(self, graph: InternalGraph):
self.graph = graph
def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
if expr.graph is not None:
yield from expr.graph.exprs
else:
yield expr
class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
"""
m_node = None # type: ModuleNode
......@@ -307,21 +308,24 @@ class TracedModule(Module):
self.m_node = node
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1:
rst = rst[0]
return rst
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
if len(outputs) == 1:
return outputs[0]
return outputs
@property
def exprs(self):
"""
Get all ``Expr`` s recursively.
def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]
:return: Iterator[Expr]
"""
return _expr_list(self)
@property
def exprs(self):
return self.graph.exprs
def flatten(self):
"""
......@@ -331,24 +335,26 @@ class TracedModule(Module):
"""
new_module = copy.deepcopy(self)
def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
call.inputs[0] = module
return (call,)
def _flatten_subgraph(graph, module, call=None):
if graph is None:
assert not isinstance(module, TracedModule)
const = Constant(module)
modulenode = const.outputs[0]
modulenode.module_type = type(module)
call.inputs[0] = modulenode
return [const, call]
exprs = []
graph = module.m_node.graph
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
expr.inputs[idx] = call.inputs[idx]
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
expr.outputs[idx] = call.outputs[idx]
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
......@@ -356,12 +362,13 @@ class TracedModule(Module):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
exprs.append(const)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
else:
exprs.append(expr)
else:
......@@ -369,7 +376,7 @@ class TracedModule(Module):
return exprs
new_module.m_node.graph._exprs = _flatten_submodule(new_module)
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
return new_module
......@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod)
builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册