提交 f2691566 编写于 作者: M Megvii Engine Team

feat(traced_module): add pytree

GitOrigin-RevId: 6c6e53521c71474c67590e0a94723a1d6be89218
上级 bee305be
......@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
from typing import Callable, List
......@@ -19,6 +19,7 @@ from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef
class Expr:
......@@ -28,9 +29,22 @@ class Expr:
inputs = None # type: List[Node]
outputs = None # type: List[Node]
def add_input(self, node):
self.inputs.append(node)
const_val = None # type: List[Any]
arg_def = None # type: TreeDef
def add_inputs(self, vals):
    # Record the inputs of this expression.
    # Values that already carry a graph Node (tensors/modules) are appended to
    # ``self.inputs`` (de-duplicated); plain builtin values are captured as
    # constants in ``self.const_val`` together with their positional index so
    # ``unflatten_args`` can re-insert them when reconstructing the call args.
    if not isinstance(vals, collections.abc.Sequence):
        # allow a single value as well as a sequence of values
        vals = (vals,)
    for val in vals:
        node = NodeMixin.get(val, None)
        if isinstance(node, (TensorNode, ModuleNode)):
            # the same node may feed several arguments; store it only once
            if node not in self.inputs:
                self.inputs.append(node)
        else:
            assert node is None
            # only instances of builtin types (int, str, bool, ...) may be
            # captured as constants
            assert type(val) in builtins.__dict__.values()
            # index of this constant within the flattened argument list
            idx = len(self.inputs) + len(self.const_val)
            self.const_val.append((idx, val))
def add_outputs(self, outputs):
self.outputs = []
......@@ -38,50 +52,31 @@ class Expr:
outputs = (outputs,)
for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
@classmethod
def get_args_node(cls, arg):
"""
Create nodes by ``arg``, which may be a container.
Return the same structure with arg.
If ``arg`` was not Tensor or Module, it will be stored as const.
:param arg: tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
def unflatten_args(self, inputs):
if self.arg_def is not None:
inputs = list(inputs)
for idx, val in self.const_val:
inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
return args, kwargs
else:
# TODO: assert arg type
return arg # as const
return inputs, {}
@classmethod
def get_arg_value(cls, inp_node, node2value):
"""
Get values from node2value by inp_node, which may be a container.
Return the same structure with inp_node.
If ``inp_node`` was not in node2value, it is a const.
:param inp_node: nodes.
:param node2value: dict from node to tensor and module.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
@property
def kwargs(self):
    # Keyword arguments of the traced call, recovered by re-inserting the
    # stored constants and unflattening ``self.inputs`` via ``arg_def``.
    _, kwargs = self.unflatten_args(self.inputs)
    return kwargs

@property
def args(self):
    # Positional arguments of the traced call, recovered by re-inserting the
    # stored constants and unflattening ``self.inputs`` via ``arg_def``.
    args, _ = self.unflatten_args(self.inputs)
    return args
# expr: None (i.e. fake expression which is used to mark input)
......@@ -144,16 +139,8 @@ class CallMethod(Expr):
self.inputs = [
module,
]
self.const_val = []
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)
if arg_name is not None:
self.arg_names.append(arg_name)
@classmethod
def make(cls, *args, **kwargs):
......@@ -162,19 +149,22 @@ class CallMethod(Expr):
return expr
def interpret(self, *inputs):
mod = inputs[0]
args = inputs[1:]
outputs = getattr(mod, self.method)(*args, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
return outputs
def __repr__(self):
return "{} = CallMethod({}, {})({})".format(
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]),
", ".join([args, kwargs]),
)
......@@ -227,13 +217,8 @@ class CallFunction(Expr):
def __init__(self, func):
assert isinstance(func, Callable)
self.func = func
self.const_val = []
self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)
@classmethod
def make(cls, *args, **kwargs):
......@@ -242,18 +227,20 @@ class CallFunction(Expr):
return expr
def interpret(self, *inputs):
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)])
outputs = self.func(**inp_dict, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
return outputs
def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs),
", ".join([args, kwargs]),
)
......
......@@ -15,6 +15,72 @@ from ...module import Module
_active_module_tracer = None
# Names of ArrayMethodMixin methods that the Patcher wraps so that tensor
# operator calls (comparisons, arithmetic, in-place/reflected variants, and
# common shape/reduction methods) are recorded while tracing a module.
BUILTIN_ARRAY_METHOD = [
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
    "__eq__",
    "__ne__",
    "__neg__",
    "__pos__",
    "__abs__",
    "__invert__",
    "__round__",
    "__floor__",
    "__ceil__",
    "__add__",
    "__sub__",
    "__mul__",
    "__matmul__",
    "__truediv__",
    "__floordiv__",
    "__mod__",
    "__pow__",
    "__lshift__",
    "__rshift__",
    "__and__",
    "__or__",
    "__xor__",
    "__radd__",
    "__rsub__",
    "__rmul__",
    "__rmatmul__",
    "__rtruediv__",
    "__rfloordiv__",
    "__rmod__",
    "__rpow__",
    "__rlshift__",
    "__rrshift__",
    "__rand__",
    "__ror__",
    "__rxor__",
    "__iadd__",
    "__isub__",
    "__imul__",
    "__imatmul__",
    "__itruediv__",
    "__ifloordiv__",
    "__imod__",
    "__ipow__",
    "__ilshift__",
    "__irshift__",
    "__iand__",
    "__ior__",
    "__ixor__",
    "T",
    "astype",
    "reshape",
    "_broadcast",
    "transpose",
    "flatten",
    "sum",
    "prod",
    "min",
    "max",
    "mean",
]
def active_module_tracer():
    # Return the currently active module tracer (the module-level
    # ``_active_module_tracer``); ``None`` when no tracing session is active.
    return _active_module_tracer
......@@ -108,9 +174,8 @@ class Patcher:
self.wrap_fn = wrap_fn
for module in self._builtin_modules:
self.patch_module(module)
for cls in self._builtin_methods:
self.patch_cls(cls)
for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
......
......@@ -13,6 +13,7 @@ import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor
from .pytree import TreeDef
class Node:
......@@ -58,6 +59,7 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
graph = None
attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef
def __repr__(self):
if self._name is None:
......
from typing import Callable, NamedTuple

# Registry mapping a container type to its (flatten, unflatten) pair.
SUPPORTED_TYPE = {}

# (flatten, unflatten) pair describing how to decompose a container into
# (children, aux_data) and how to rebuild it from those pieces.
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten, unflatten):
    """Register a container ``type`` so the pytree utilities can traverse it.

    :param type: the container class.
    :param flatten: callable(container) -> (children_list, aux_data).
    :param unflatten: callable(children_list, aux_data) -> container.
    """
    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
# BUGFIX: rebuild tuples as tuples (the original rebuilt a ``list`` here, so a
# flatten/unflatten round trip silently converted tuples into lists).
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(
    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)


def tree_flatten(
    values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
):
    """Flatten a nested container into a flat leaf list plus a :class:`TreeDef`.

    :param values: a (possibly nested) container of registered types; anything
        whose type is not in ``SUPPORTED_TYPE`` is treated as a leaf.
    :param leaf_type: maps a leaf value to the type recorded in its LeafDef.
    :param is_leaf: predicate asserted on every leaf encountered.
    :return: ``(leaves, treedef)`` such that ``treedef.unflatten(leaves)``
        rebuilds ``values``.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values)
        return [values,], LeafDef(leaf_type(values))
    rst = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
    for v in children_values:
        # BUGFIX: propagate ``is_leaf`` as well; the original dropped it here,
        # so the predicate was never applied below the top level.
        v_list, treedef = tree_flatten(v, leaf_type, is_leaf)
        rst.extend(v_list)
        children_defs.append(treedef)
    return rst, TreeDef(type(values), aux_data, children_defs)


class TreeDef:
    """Structure descriptor of a flattened container.

    Stores the container ``type``, its ``aux_data`` (e.g. dict keys), and one
    child TreeDef per child value; ``num_leaves`` is the total leaf count.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        self.num_leaves = sum(ch.num_leaves for ch in children_defs)

    def unflatten(self, leaves):
        """Rebuild the original container from a flat sequence of leaves."""
        assert len(leaves) == self.num_leaves
        start = 0
        children = []
        for ch in self.children_defs:
            # each child consumes exactly its own number of leaves, in order
            children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
            start += ch.num_leaves
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __eq__(self, other):
        return (
            self.type == other.type
            and self.aux_data == other.aux_data
            and self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def __repr__(self):
        return "{}[{}]".format(self.type.__name__, self.children_defs)


class LeafDef(TreeDef):
    """TreeDef of a single leaf value; ``type`` is the accepted leaf type(s)."""

    def __init__(self, type):
        super().__init__(type, None, [])
        self.num_leaves = 1

    def unflatten(self, leaves):
        assert len(leaves) == 1
        # ``self.type`` may be a tuple of types (see ``leaf_type`` callbacks)
        assert isinstance(leaves[0], self.type), self.type
        return leaves[0]

    def __repr__(self):
        return "Leaf({})".format(self.type.__name__)
......@@ -9,9 +9,11 @@
import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
......@@ -28,6 +30,16 @@ from .module_tracer import (
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten
def _leaf_type(node):
    """Return the type (or tuple of types) recorded for ``node`` when it is
    stored as a pytree leaf during tracing.

    Tensors and modules may appear either as raw values or as their graph-node
    wrappers, so a tuple of acceptable types is returned for them; any other
    value is recorded by its concrete type.
    """
    if isinstance(node, RawTensor):
        return (Tensor, TensorNode)
    if isinstance(node, (NodeMixin, Module)):
        return (Module, ModuleNode, NodeMixin)
    return type(node)
class InternalGraph:
......@@ -65,9 +77,7 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)
......@@ -80,37 +90,39 @@ class InternalGraph:
)
def _get_meth_name(obj, func):
for cls in type(obj).mro():
for k, v in cls.__dict__.items():
if v == func:
return k
return None
def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs):
def wrapped_fn(*args, **kwargs):
if is_tracing_module():
unset_module_tracing()
const_kwargs = {}
arg_names = orig_func.__code__.co_varnames
if orig_func.__qualname__.split(".").__len__() > 1:
# FIXME: a robust way to distinguish method and function. <XP>
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
else:
call_node = CallFunction.make(orig_func)
def add_input(inp, varname=None):
node = Expr.get_args_node(inp)
if node is not None:
call_node.add_input(node, varname)
else:
const_kwargs[varname] = inp
for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_inputs(inputs)
call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*inputs, **kwargs)
return orig_func(*args, **kwargs)
return wrapped_fn
......@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_arg_def = None # type: TreeDef
# Attributes that belong to the builder itself (not to the wrapped module);
# ``build()`` skips these names when copying instance state onto the
# resulting TracedModule.
# BUGFIX: the original wrote ``"_arg_def" "build"`` — implicit string-literal
# concatenation yielding the single entry "_arg_defbuild", so neither
# "_arg_def" nor "build" was actually excluded.
__builder_attributes__ = [
    "_mod",
    "_body",
    "_NodeMixin__node",
    "_is_builtin",
    "_is_traced",
    "_arg_def",
    "build",
]
def __init__(self, mod):
......@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin):
node = NodeMixin.get(self)
node.graph = self._body
node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
......@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin):
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module
def __call__(self, *inputs, **kwargs):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)
for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph
def mark_constant(x):
node = NodeMixin.get(x, None)
if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x))
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
for i in inputs:
mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = CallMethod.make(NodeMixin.get(self))
def add_input(x):
callnode.add_input(NodeMixin.get(x))
callnode.add_inputs(inputs)
for i in inputs:
add_input(i)
for k, v in kwargs.items():
add_input(v)
callnode.arg_def = tree_def
if self._is_builtin or self._is_traced:
unset_module_tracing()
outputs = self._mod(*inputs, **kwargs)
outputs = self._mod(*args, **kwargs)
set_module_tracing()
if self._is_builtin:
self._body = None
......@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin):
)
# prepare args and kwargs for inner graph
def wrap(x):
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
wrapped = copy.copy(x) # FIXME
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
)
return wrapped
args = []
for i in inputs:
args = [self]
for i in inputs[1:]:
args.append(wrap(i))
for k, v in kwargs.items():
kwargs[k] = wrap(v)
args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs)
outputs = type(self._mod).forward(*args, **kwargs)
for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
......@@ -269,8 +282,10 @@ class TracedModule(Module):
super(TracedModule, self).__init__()
self.m_node = node
def forward(self, *inputs):
rst = self.m_node.graph.interpret(self, *inputs)
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1:
rst = rst[0]
return rst
......@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
def _register_all_builtin_module():
from inspect import getmembers, isclass
for sub_mod in [M, M.qat, M.quantized]:
for m in getmembers(sub_mod):
......@@ -357,7 +371,7 @@ def _register_all_builtin_module():
module_tracer.register_as_builtin(m[1])
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
Traces module ``mod`` and returns corresponding TracedModule.
......@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
builder(*inputs, **kwargs)
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
builder(*args, **kwargs)
active_module_tracer().pop_scope()
return builder.build()
finally:
set_active_module_tracer(None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册