diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index ec07af739936b5c1e7721144d73c688fd42026a4..63318e2d5206d305f334d43a1997a7b05fc8081c 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -7,7 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - +import builtins import collections from typing import Callable, List @@ -19,6 +19,7 @@ from ...module import Module from ...tensor import Tensor from .module_tracer import active_module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode +from .pytree import TreeDef class Expr: @@ -28,9 +29,22 @@ class Expr: inputs = None # type: List[Node] outputs = None # type: List[Node] - - def add_input(self, node): - self.inputs.append(node) + const_val = None # type: List[Any] + arg_def = None # type: TreeDef + + def add_inputs(self, vals): + if not isinstance(vals, collections.abc.Sequence): + vals = (vals,) + for val in vals: + node = NodeMixin.get(val, None) + if isinstance(node, (TensorNode, ModuleNode)): + if node not in self.inputs: + self.inputs.append(node) + else: + assert node is None + assert type(val) in builtins.__dict__.values() + idx = len(self.inputs) + len(self.const_val) + self.const_val.append((idx, val)) def add_outputs(self, outputs): self.outputs = [] @@ -38,50 +52,31 @@ class Expr: outputs = (outputs,) for i in outputs: + assert isinstance(i, RawTensor) self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) for i, node in zip(outputs, self.outputs,): NodeMixin.wrap_safe(i, node) - @classmethod - def get_args_node(cls, arg): - """ - Create nodes by ``arg``, which may be a container. - Return the same structure with arg. - - If ``arg`` was not Tensor or Module, it will be stored as const. - - :param arg: tensor, module or const. - """ - if isinstance(arg, (RawTensor, Module)): - if not NodeMixin.get(arg, None): - NodeMixin.wrap_safe(arg, Constant.make(arg)) - return NodeMixin.get(arg) - elif isinstance(arg, collections.abc.Sequence): - seq_cls = type(arg) - return seq_cls([Expr.get_args_node(a) for a in arg]) + def unflatten_args(self, inputs): + if self.arg_def is not None: + inputs = list(inputs) + for idx, val in self.const_val: + inputs.insert(idx, val) + args, kwargs = self.arg_def.unflatten(inputs) + return args, kwargs else: - # TODO: assert arg type - return arg # as const + return inputs, {} - @classmethod - def get_arg_value(cls, inp_node, node2value): - """ - Get values from node2value by inp_node, which may be a container. - Return the same structure with inp_node. - - If ``inp_node`` was not in node2value, it is a const. - - :param inp_node: nodes. - :param node2value: dict from node to tensor and module. - """ - if inp_node in node2value: - return node2value[inp_node] - elif isinstance(inp_node, collections.abc.Sequence): - seq_cls = type(inp_node) - return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) - else: - return inp_node + @property + def kwargs(self): + _, kwargs = self.unflatten_args(self.inputs) + return kwargs + + @property + def args(self): + args, _ = self.unflatten_args(self.inputs) + return args # expr: None (i.e. fake expression which is used to mark input) @@ -144,16 +139,8 @@ class CallMethod(Expr): self.inputs = [ module, ] + self.const_val = [] self.method = method - self.arg_names = [] - self.kwargs = {} # const kwargs - - def add_input(self, node, arg_name=None): - if arg_name == "self": # FIXME: - return - self.inputs.append(node) - if arg_name is not None: - self.arg_names.append(arg_name) @classmethod def make(cls, *args, **kwargs): @@ -162,19 +149,22 @@ class CallMethod(Expr): return expr def interpret(self, *inputs): - mod = inputs[0] - args = inputs[1:] - outputs = getattr(mod, self.method)(*args, **self.kwargs) + args, kwargs = self.unflatten_args(inputs) + obj = args[0] + args = args[1:] + outputs = getattr(obj, self.method)(*args, **kwargs) if isinstance(outputs, RawTensor): outputs = (outputs,) return outputs def __repr__(self): - return "{} = CallMethod({}, {})({})".format( + args = ", ".join(str(i) for i in self.args[1:]) + kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) + return "{} = {}.{}({})".format( ", ".join(str(i) for i in self.outputs), self.inputs[0], self.method, - ", ".join(str(i) for i in self.inputs[1:]), + ", ".join([args, kwargs]), ) @@ -227,13 +217,8 @@ class CallFunction(Expr): def __init__(self, func): assert isinstance(func, Callable) self.func = func + self.const_val = [] self.inputs = [] - self.arg_names = [] - self.kwargs = {} # const kwargs - - def add_input(self, node, arg_name): - self.inputs.append(node) - self.arg_names.append(arg_name) @classmethod def make(cls, *args, **kwargs): @@ -242,18 +227,20 @@ class CallFunction(Expr): return expr def interpret(self, *inputs): - inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) - outputs = self.func(**inp_dict, **self.kwargs) + args, kwargs = self.unflatten_args(inputs) + outputs = self.func(*args, **kwargs) outputs = ( outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) ) return outputs def __repr__(self): + args = ", ".join(str(i) for i in self.args) + kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) return "{} = {}({})".format( ", ".join(str(i) for i in self.outputs), self.func.__module__ + "." + self.func.__name__, - ", ".join(str(i) for i in self.inputs), + ", ".join([args, kwargs]), ) diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py index f8eb67cfc48f8aefe08bcfbc28ceb5f01a9566c1..57d69dbdbbf70feb5b46ab4d6052499fcbfa7639 100644 --- a/imperative/python/megengine/experimental/traced_module/module_tracer.py +++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py @@ -15,6 +15,72 @@ from ...module import Module _active_module_tracer = None +BUILTIN_ARRAY_METHOD = [ + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__eq__", + "__ne__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + "__round__", + "__floor__", + "__ceil__", + "__add__", + "__sub__", + "__mul__", + "__matmul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__or__", + "__xor__", + "__radd__", + "__rsub__", + "__rmul__", + "__rmatmul__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rpow__", + "__rlshift__", + "__rrshift__", + "__rand__", + "__ror__", + "__rxor__", + "__iadd__", + "__isub__", + "__imul__", + "__imatmul__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__ipow__", + "__ilshift__", + "__irshift__", + "__iand__", + "__ior__", + "__ixor__", + "T", + "astype", + "reshape", + "_broadcast", + "transpose", + "flatten", + "sum", + "prod", + "min", + "max", + "mean", +] + def active_module_tracer(): return _active_module_tracer @@ -108,9 +174,8 @@ class Patcher: self.wrap_fn = wrap_fn for module in self._builtin_modules: self.patch_module(module) - - for cls in self._builtin_methods: - self.patch_cls(cls) + for meth in BUILTIN_ARRAY_METHOD: + self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) for i, j in self._builtin_functions: if id(i) not in self.visited_frames_ids: diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index e45aa281ee2923f46ec3a08c82fc987d2fcd9e0e..066fefe6f7bdcf7fb9f2d5f096ae063bc6d22179 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -13,6 +13,7 @@ import numpy from ...core._imperative_rt.core2 import Tensor as RawTensor from ...module import Module from ...tensor import Tensor +from .pytree import TreeDef class Node: @@ -58,6 +59,7 @@ class ModuleNode(Node): module_type = Module # type: Type[Module] graph = None attr_type_map = None # type: Dict[str, Type[Any]] + arg_def = None # type: TreeDef def __repr__(self): if self._name is None: diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..73b2c05d1ce6e11fe9d64f7c6a8db286d45fb33b --- /dev/null +++ b/imperative/python/megengine/experimental/traced_module/pytree.py @@ -0,0 +1,80 @@ +from typing import Callable, NamedTuple + +SUPPORTED_TYPE = {} + +NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) + + +def register_supported_type(type, flatten, unflatten): + SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) + + +register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) +register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x)) +register_supported_type( + dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x)) +) +register_supported_type( + slice, + lambda x: ([x.start, x.stop, x.step], None), + lambda x, aux_data: slice(x[0], x[1], x[2]), +) + + +def tree_flatten( + values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True +): + if type(values) not in SUPPORTED_TYPE: + assert is_leaf(values) + return [values,], LeafDef(leaf_type(values)) + rst = [] + children_defs = [] + children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) + for v in children_values: + v_list, treedef = tree_flatten(v, leaf_type) + rst.extend(v_list) + children_defs.append(treedef) + + return rst, TreeDef(type(values), aux_data, children_defs) + + +class TreeDef: + def __init__(self, type, aux_data, children_defs): + self.type = type + self.aux_data = aux_data + self.children_defs = children_defs + self.num_leaves = sum(ch.num_leaves for ch in children_defs) + + def unflatten(self, leaves): + assert len(leaves) == self.num_leaves + start = 0 + children = [] + for ch in self.children_defs: + children.append(ch.unflatten(leaves[start : start + ch.num_leaves])) + start += ch.num_leaves + return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) + + def __eq__(self, other): + return ( + self.type == other.type + and self.aux_data == other.aux_data + and self.num_leaves == other.num_leaves + and self.children_defs == other.children_defs + ) + + def __repr__(self): + return "{}[{}]".format(self.type.__name__, self.children_defs) + + +class LeafDef(TreeDef): + def __init__(self, type): + super().__init__(type, None, []) + self.num_leaves = 1 + + def unflatten(self, leaves): + assert len(leaves) == 1 + assert isinstance(leaves[0], self.type), self.type + return leaves[0] + + def __repr__(self): + return "Leaf({})".format(self.type.__name__) diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 9dd76035865ba825350e2ae65d4bd1eeb8897119..c6faa03059ef18bc7e7f31b53d632c16fac92e21 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -9,9 +9,11 @@ import collections import copy import functools +from inspect import getmembers, isclass, ismethod from typing import List, Type from ... import module as M +from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import ( is_tracing_module, set_module_tracing, @@ -28,6 +30,16 @@ from .module_tracer import ( set_active_module_tracer, ) from .node import ModuleNode, Node, NodeMixin, TensorNode +from .pytree import tree_flatten + + +def _leaf_type(node): + if isinstance(node, RawTensor): + return (Tensor, TensorNode) + elif isinstance(node, (NodeMixin, Module)): + return (Module, ModuleNode, NodeMixin) + else: + return type(node) class InternalGraph: @@ -65,9 +77,7 @@ class InternalGraph: for n, v in zip(self._inputs, inputs): node2value[n] = v for expr in self._exprs: - values = expr.interpret( - *list(Expr.get_arg_value(i, node2value) for i in expr.inputs) - ) + values = expr.interpret(*list(node2value[i] for i in expr.inputs)) for n, v in zip(expr.outputs, values): node2value[n] = v return list(node2value[i] for i in self._outputs) @@ -80,37 +90,39 @@ class InternalGraph: ) +def _get_meth_name(obj, func): + for cls in type(obj).mro(): + for k, v in cls.__dict__.items(): + if v == func: + return k + return None + + def _wrapped_function(orig_func): @functools.wraps(orig_func) - def wrapped_fn(*inputs, **kwargs): + def wrapped_fn(*args, **kwargs): if is_tracing_module(): unset_module_tracing() - const_kwargs = {} - arg_names = orig_func.__code__.co_varnames - if orig_func.__qualname__.split(".").__len__() > 1: - # FIXME: a robust way to distinguish method and function. + inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) + for i in inputs: + if not NodeMixin.get(i, None): + if isinstance(i, (RawTensor, NodeMixin)): + NodeMixin.wrap_safe(i, Constant.make(i)) + meth_name = _get_meth_name(args[0], wrapped_fn) + if meth_name: self = inputs[0] - call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) + call_node = CallMethod.make(NodeMixin.get(self), meth_name) else: call_node = CallFunction.make(orig_func) - def add_input(inp, varname=None): - node = Expr.get_args_node(inp) - if node is not None: - call_node.add_input(node, varname) - else: - const_kwargs[varname] = inp - - for ind, inp in enumerate(inputs): - add_input(inp, arg_names[ind]) - for k, v in kwargs.items(): - add_input(v, k) - call_node.kwargs = const_kwargs - outputs = orig_func(*inputs, **kwargs) + call_node.add_inputs(inputs) + + call_node.arg_def = tree_def + outputs = orig_func(*args, **kwargs) call_node.add_outputs(outputs) set_module_tracing() return outputs - return orig_func(*inputs, **kwargs) + return orig_func(*args, **kwargs) return wrapped_fn @@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin): _mod = None # type: Module _body = None # type: InternalGraph _is_builtin = None # type: bool - + _arg_def = None # type: TreeDef __builder_attributes__ = [ "_mod", "_body", "_NodeMixin__node", "_is_builtin", "_is_traced", - "build", + "_arg_def" "build", ] def __init__(self, mod): @@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin): node = NodeMixin.get(self) node.graph = self._body node.attr_type_map = {} + node.arg_def = self._arg_def traced_module = TracedModule(node) for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: @@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin): traced_module.m_node.attr_type_map[k] = type(v) return traced_module - def __call__(self, *inputs, **kwargs): + def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) + for arg in args: + assert isinstance(arg, RawTensor) + for k, v in kwargs.items(): + assert isinstance(v, RawTensor) # prepare args and kwargs for inner graph def mark_constant(x): node = NodeMixin.get(x, None) if node is None: # capture as constant NodeMixin.wrap(x, lambda: Constant.make(x)) + inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) + if self._arg_def is None: + self._arg_def = tree_def + assert self._arg_def == tree_def for i in inputs: mark_constant(i) - for k, v in kwargs.items(): - mark_constant(v) callnode = CallMethod.make(NodeMixin.get(self)) - def add_input(x): - callnode.add_input(NodeMixin.get(x)) + callnode.add_inputs(inputs) - for i in inputs: - add_input(i) - for k, v in kwargs.items(): - add_input(v) + callnode.arg_def = tree_def if self._is_builtin or self._is_traced: unset_module_tracing() - outputs = self._mod(*inputs, **kwargs) + outputs = self._mod(*args, **kwargs) set_module_tracing() if self._is_builtin: self._body = None @@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin): ) # prepare args and kwargs for inner graph def wrap(x): - # wrapped = copy.copy(x) # FIXME - wrapped = x # FIXME: + wrapped = copy.copy(x) # FIXME NodeMixin.wrap( wrapped, lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), ) return wrapped - args = [] - for i in inputs: + args = [self] + for i in inputs[1:]: args.append(wrap(i)) - for k, v in kwargs.items(): - kwargs[k] = wrap(v) + args, kwargs = tree_def.unflatten(args) active_module_tracer().patcher.auto_patch( getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) ) - outputs = type(self._mod).forward(self, *args, **kwargs) + outputs = type(self._mod).forward(*args, **kwargs) for i in ( outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) @@ -269,8 +282,10 @@ class TracedModule(Module): super(TracedModule, self).__init__() self.m_node = node - def forward(self, *inputs): - rst = self.m_node.graph.interpret(self, *inputs) + def forward(self, *args, **kwargs): + inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) + assert treedef == self.m_node.arg_def + rst = self.m_node.graph.interpret(*inputs) if len(rst) == 1: rst = rst[0] return rst @@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: def _register_all_builtin_module(): - from inspect import getmembers, isclass for sub_mod in [M, M.qat, M.quantized]: for m in getmembers(sub_mod): @@ -357,7 +371,7 @@ def _register_all_builtin_module(): module_tracer.register_as_builtin(m[1]) -def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: +def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: """ Traces module ``mod`` and returns corresponding TracedModule. @@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule builder = TracedModuleBuilder(mod) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) - + inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): - NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) - for k, v in kwargs.items(): - NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) - - builder(*inputs, **kwargs) + NodeMixin.wrap_safe( + i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) + ) + builder(*args, **kwargs) active_module_tracer().pop_scope() - return builder.build() finally: set_active_module_tracer(None)