提交 4bb25369 编写于 作者: M Megvii Engine Team

feat(traced_module): let CallFunction own graph

GitOrigin-RevId: 66cdbca7e54df07576a984c3fd48d3bcafb678f1
上级 9a6a3793
...@@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module ...@@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module
from ...core.ops.special import Const from ...core.ops.special import Const
from ...module import Module from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .module_tracer import active_module_tracer from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef from .pytree import TreeDef
...@@ -148,6 +148,15 @@ class CallMethod(Expr): ...@@ -148,6 +148,15 @@ class CallMethod(Expr):
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope().insert(expr)
return expr return expr
@property
def graph(self):
if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
return None
def interpret(self, *inputs): def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs) args, kwargs = self.unflatten_args(inputs)
obj = args[0] obj = args[0]
...@@ -252,7 +261,9 @@ class Constant(Expr): ...@@ -252,7 +261,9 @@ class Constant(Expr):
_constant_cache = {} _constant_cache = {}
def __init__(self, c): def __init__(self, c):
# TODO: type check, since not all types should be captured as constant assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
self.value = c self.value = c
self.inputs = [] self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c) node_cls = NodeMixin.get_wrapped_type(c)
......
...@@ -57,9 +57,13 @@ class ModuleNode(Node): ...@@ -57,9 +57,13 @@ class ModuleNode(Node):
""" """
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
graph = None
attr_type_map = None # type: Dict[str, Type[Any]] attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
......
...@@ -25,7 +25,7 @@ def _dict_flatten(inp): ...@@ -25,7 +25,7 @@ def _dict_flatten(inp):
for key, value in sorted(inp.items()): for key, value in sorted(inp.items()):
results.append(value) results.append(value)
aux_data.append(key) aux_data.append(key)
return results, aux_data return results, tuple(aux_data)
def _dict_unflatten(inps, aux_data): def _dict_unflatten(inps, aux_data):
...@@ -43,16 +43,23 @@ register_supported_type( ...@@ -43,16 +43,23 @@ register_supported_type(
def tree_flatten( def tree_flatten(
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True values,
leaf_type: Callable = lambda x: type(x),
is_leaf: Callable = lambda _: True,
is_const_leaf: Callable = lambda _: False,
): ):
if type(values) not in SUPPORTED_TYPE: if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values) assert is_leaf(values)
return [values,], LeafDef(leaf_type(values)) node = LeafDef(leaf_type(values))
if is_const_leaf(values):
node.const_val = values
return [values,], node
rst = [] rst = []
children_defs = [] children_defs = []
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
for v in children_values: for v in children_values:
v_list, treedef = tree_flatten(v, leaf_type) v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
rst.extend(v_list) rst.extend(v_list)
children_defs.append(treedef) children_defs.append(treedef)
...@@ -75,6 +82,18 @@ class TreeDef: ...@@ -75,6 +82,18 @@ class TreeDef:
start += ch.num_leaves start += ch.num_leaves
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
def __hash__(self):
return hash(
tuple(
[
self.type,
self.aux_data,
self.num_leaves,
tuple([hash(x) for x in self.children_defs]),
]
)
)
def __eq__(self, other): def __eq__(self, other):
return ( return (
self.type == other.type self.type == other.type
...@@ -93,11 +112,20 @@ class LeafDef(TreeDef): ...@@ -93,11 +112,20 @@ class LeafDef(TreeDef):
type = (type,) type = (type,)
super().__init__(type, None, []) super().__init__(type, None, [])
self.num_leaves = 1 self.num_leaves = 1
self.const_val = None
def unflatten(self, leaves): def unflatten(self, leaves):
assert len(leaves) == 1 assert len(leaves) == 1
assert isinstance(leaves[0], self.type), self.type assert isinstance(leaves[0], self.type), self.type
return leaves[0] return leaves[0]
def __eq__(self, other):
return self.type == other.type and self.const_val == other.const_val
def __hash__(self):
return hash(tuple([self.type, self.const_val]))
def __repr__(self): def __repr__(self):
return "Leaf({})".format(", ".join(t.__name__ for t in self.type)) return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val
)
...@@ -42,6 +42,12 @@ def _leaf_type(node): ...@@ -42,6 +42,12 @@ def _leaf_type(node):
return type(node) return type(node)
def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
return True
class InternalGraph: class InternalGraph:
""" """
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
...@@ -72,6 +78,10 @@ class InternalGraph: ...@@ -72,6 +78,10 @@ class InternalGraph:
def outputs(self): def outputs(self):
return self._outputs return self._outputs
@property
def exprs(self):
return _expr_list(self)
def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)
...@@ -111,7 +121,9 @@ def _wrapped_function(orig_func): ...@@ -111,7 +121,9 @@ def _wrapped_function(orig_func):
def wrapped_fn(*args, **kwargs): def wrapped_fn(*args, **kwargs):
if is_tracing_module(): if is_tracing_module():
unset_module_tracing() unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) inputs, tree_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for i in inputs: for i in inputs:
if not NodeMixin.get(i, None): if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)): if isinstance(i, (RawTensor, NodeMixin)):
...@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin): ...@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
_body = None # type: InternalGraph _body = None # type: InternalGraph
_is_builtin = None # type: bool _is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [ __builder_attributes__ = [
"_mod", "_mod",
"_body", "_body",
"_NodeMixin__node", "_NodeMixin__node",
"_is_builtin", "_is_builtin",
"_is_traced", "build",
"_arg_def" "build",
] ]
def __init__(self, mod): def __init__(self, mod, is_top_module=False):
super(TracedModuleBuilder, self).__init__() super(TracedModuleBuilder, self).__init__()
self._mod = mod self._mod = mod
self._body = InternalGraph() self._body = None
self._is_traced = False
self._is_builtin = module_tracer.is_builtin(mod) self._is_builtin = module_tracer.is_builtin(mod)
def build(self): def build(self):
...@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin): ...@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin):
return self._mod return self._mod
else: else:
node = NodeMixin.get(self) node = NodeMixin.get(self)
node.graph = self._body
node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node) traced_module = TracedModule(node)
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
...@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin): ...@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module) assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)
for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def mark_constant(x): def mark_constant(x):
node = NodeMixin.get(x, None) node = NodeMixin.get(x, None)
if node is None: # capture as constant if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x)) NodeMixin.wrap(x, lambda: Constant.make(x))
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) inputs, tree_def = tree_flatten(
if self._arg_def is None: ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
self._arg_def = tree_def )
assert self._arg_def == tree_def
for i in inputs: for i in inputs:
mark_constant(i) mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self)) callnode = CallMethod.make(NodeMixin.get(self))
...@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin): ...@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def callnode.arg_def = tree_def
if self._is_builtin or self._is_traced: if self._is_builtin:
unset_module_tracing() unset_module_tracing()
outputs = self._mod(*args, **kwargs) outputs = self._mod(*args, **kwargs)
set_module_tracing() set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
else: else:
self._body = InternalGraph()
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
orig_self = NodeMixin.get(self) orig_self = NodeMixin.get(self)
...@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin): ...@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.wrap_safe(self, orig_self) NodeMixin.wrap_safe(self, orig_self)
self._is_traced = True
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
# rebind output to outer graph # rebind output to outer graph
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
return outputs return outputs
def __getattr__(self, name): def __getattr__(self, name):
...@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin): ...@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin):
class _expr_list: class _expr_list:
def __init__(self, module: "TracedModule"): def __init__(self, graph: InternalGraph):
self.module = module self.graph = graph
def __iter__(self): def __iter__(self):
graph = self.module.m_node.graph for expr in self.graph._exprs:
for expr in graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr yield expr
assert isinstance(expr.inputs[0].expr, GetAttr) if expr.graph is not None:
(obj,) = expr.inputs[0].expr.interpret(self.module) yield from expr.graph.exprs
if isinstance(obj, TracedModule): else:
yield from obj.exprs yield expr
yield expr
class TracedModule(Module): class TracedModule(Module):
""" """
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
""" """
m_node = None # type: ModuleNode m_node = None # type: ModuleNode
...@@ -307,21 +308,24 @@ class TracedModule(Module): ...@@ -307,21 +308,24 @@ class TracedModule(Module):
self.m_node = node self.m_node = node
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) inputs, treedef = tree_flatten(
assert treedef == self.m_node.arg_def ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
rst = self.m_node.graph.interpret(*inputs) )
if len(rst) == 1: assert treedef in self.m_node.argdef_graph_map
rst = rst[0] inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
return rst outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
if len(outputs) == 1:
return outputs[0]
return outputs
@property @property
def exprs(self): def graph(self):
""" assert len(self.m_node.argdef_graph_map) == 1
Get all ``Expr`` s recursively. return list(self.m_node.argdef_graph_map.values())[0]
:return: Iterator[Expr] @property
""" def exprs(self):
return _expr_list(self) return self.graph.exprs
def flatten(self): def flatten(self):
""" """
...@@ -331,24 +335,26 @@ class TracedModule(Module): ...@@ -331,24 +335,26 @@ class TracedModule(Module):
""" """
new_module = copy.deepcopy(self) new_module = copy.deepcopy(self)
def _flatten_submodule(module, call=None): def _flatten_subgraph(graph, module, call=None):
if not isinstance(module, TracedModule): if graph is None:
call.inputs[0] = module assert not isinstance(module, TracedModule)
return (call,) const = Constant(module)
modulenode = const.outputs[0]
modulenode.module_type = type(module)
call.inputs[0] = modulenode
return [const, call]
exprs = [] exprs = []
graph = module.m_node.graph
for expr in graph._exprs: for expr in graph._exprs:
# replace inputs for submodule's expr # replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs): for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs: if call and inp in graph._inputs:
expr.inputs[idx] = call.inputs[idx] inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
# replace outputs for submodule's expr # replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs): for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs: if call and outp in graph._outputs:
expr.outputs[idx] = call.outputs[idx] oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
if isinstance(expr, GetAttr): if isinstance(expr, GetAttr):
# replace GetAttr with Constant # replace GetAttr with Constant
...@@ -356,12 +362,13 @@ class TracedModule(Module): ...@@ -356,12 +362,13 @@ class TracedModule(Module):
const = Constant(getattr(module, expr.name)) const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs const.outputs = expr.outputs
exprs.append(const) exprs.append(const)
elif isinstance(expr, CallMethod): elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0] obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode): if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr) assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module) (obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr)) exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
else: else:
exprs.append(expr) exprs.append(expr)
else: else:
...@@ -369,7 +376,7 @@ class TracedModule(Module): ...@@ -369,7 +376,7 @@ class TracedModule(Module):
return exprs return exprs
new_module.m_node.graph._exprs = _flatten_submodule(new_module) new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
return new_module return new_module
...@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ...@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
global_scope = InternalGraph() global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod) builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册