From 763c56f3b9a225c2b0d173fdcf8131298de438fb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 4 Mar 2021 14:53:22 +0800
Subject: [PATCH] feat(imperative): add traced module

GitOrigin-RevId: 28c3503f2eaca979242c19c7d7495358daa0a8c4
---
 imperative/python/megengine/__init__.py | 1 +
 .../python/megengine/experimental/__init__.py | 1 +
 .../experimental/traced_module/__init__.py | 12 +
 .../experimental/traced_module/expr.py | 215 +++++++++++++
 .../traced_module/module_tracer.py | 52 +++
 .../experimental/traced_module/node.py | 123 ++++++++
 .../traced_module/traced_module.py | 295 ++++++++++++++++++
 7 files changed, 699 insertions(+)
 create mode 100644 imperative/python/megengine/experimental/traced_module/expr.py
 create mode 100644 imperative/python/megengine/experimental/traced_module/module_tracer.py
 create mode 100644 imperative/python/megengine/experimental/traced_module/node.py
 create mode 100644 imperative/python/megengine/experimental/traced_module/traced_module.py

diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py
index ef453fc7..f149bcfd 100644
--- a/imperative/python/megengine/__init__.py
+++ b/imperative/python/megengine/__init__.py
@@ -130,3 +130,4 @@ import megengine.optimizer
 import megengine.quantization
 import megengine.random
 import megengine.utils
+import megengine.experimental
diff --git a/imperative/python/megengine/experimental/__init__.py b/imperative/python/megengine/experimental/__init__.py
index 19b1fc6b..d263b902 100644
--- a/imperative/python/megengine/experimental/__init__.py
+++ b/imperative/python/megengine/experimental/__init__.py
@@ -6,4 +6,5 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from . import traced_module
 from .weight_scaler import get_scaled_model
diff --git a/imperative/python/megengine/experimental/traced_module/__init__.py b/imperative/python/megengine/experimental/traced_module/__init__.py
index f92f1aa0..cad44a0c 100644
--- a/imperative/python/megengine/experimental/traced_module/__init__.py
+++ b/imperative/python/megengine/experimental/traced_module/__init__.py
@@ -5,3 +5,15 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ...core._imperative_rt.core2 import set_cpp_apply_module_trace
+from .traced_module import (
+    TracedModule,
+    _register_all_builtin_module,
+    cpp_apply_module_trace,
+    register_as_builtin,
+    trace_module,
+)
+
+_register_all_builtin_module()
+set_cpp_apply_module_trace(cpp_apply_module_trace)
diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py
new file mode 100644
index 00000000..ad9ed301
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/expr.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+
+import collections
+from typing import List
+
+from ...core._imperative_rt import OpDef
+from ...core._imperative_rt.core2 import Tensor as RawTensor
+from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ...core.ops.special import Const
+from ...tensor import Tensor
+from .module_tracer import active_module_tracer
+from .node import ModuleNode, Node, NodeMixin, TensorNode
+
+
+class Expr:
+    """
+    ``Expr`` represents the operations (i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``.
+    """
+
+    inputs = None  # type: List[Node]
+    outputs = None  # type: List[Node]
+
+
+# expr: None (i.e. fake expression which is used to mark input)
+class Input(Expr):
+    name = None
+
+    def __init__(self, name=None, type=None):
+        self.inputs = []
+        node_cls = type if type else Node
+        self.outputs = [
+            node_cls(self, name=name),
+        ]
+        self.name = name
+
+    @classmethod
+    def make(cls, *args, **kwargs):
+        expr = cls(*args, **kwargs)
+        active_module_tracer().current_scope().add_input(expr.outputs[0])
+        return expr.outputs[0]
+
+    def __repr__(self):
+        return "{} = Input({})".format(self.outputs[0], self.name)
+
+
+# expr: outputs = getattr(inputs[0], self.name)
+class GetAttr(Expr):
+    name = None
+
+    def __init__(self, module, name, type=None):
+        assert isinstance(module, ModuleNode)
+        self.inputs = [
+            module,
+        ]
+        self.name = name
+        node_cls = type if type else Node
+        self.outputs = [
+            node_cls(self),
+        ]
+
+    @classmethod
+    def make(cls, *args, **kwargs):
+        expr = cls(*args, **kwargs)
+        active_module_tracer().current_scope().insert(expr)
+        expr.outputs[0]._name = expr.name
+        return expr.outputs[0]
+
+    def interpret(self, *inputs):
+        return (getattr(inputs[0], self.name),)
+
+    def __repr__(self):
+        return '{} = GetAttr({}, "{}")'.format(
+            self.outputs[0], self.inputs[0], self.name
+        )
+
+
+# expr: outputs = inputs[0].__call__(*inputs[1:])
+class Call(Expr):
+    def __init__(self, module):
+        assert isinstance(module, ModuleNode)
+        self.inputs = [
+            module,
+        ]
+
+    def add_input(self, node):
+        self.inputs.append(node)
+
+    def add_outputs(self, references):
+        self.outputs = []
+        if not isinstance(references, collections.abc.Sequence):
+            references = (references,)
+
+        for i in references:
+            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
+
+    @classmethod
+    def make(cls, *args, **kwargs):
+        expr = cls(*args, **kwargs)
+        active_module_tracer().current_scope().insert(expr)
+        return expr
+
+    def interpret(self, *inputs):
+        mod = inputs[0]
+        args = inputs[1:]
+        outputs = mod(*args)
+        if isinstance(outputs, RawTensor):
+            outputs = (outputs,)
+        return outputs
+
+    def __repr__(self):
+        return "{} = Call({})({})".format(
+            ", ".join(str(i) for i in self.outputs),
+            self.inputs[0],
+            ", ".join(str(i) for i in self.inputs[1:]),
+        )
+
+
+# expr: outputs = apply(self.opdef, *inputs)
+class Apply(Expr):
+    opdef = None
+
+    def __init__(self, opdef):
+        assert isinstance(opdef, OpDef)
+        self.opdef = opdef
+        self.inputs = []
+
+    def add_input(self, node):
+        self.inputs.append(node)
+
+    def add_outputs(self, references):
+        self.outputs = []
+        if not isinstance(references, collections.abc.Sequence):
+            references = (references,)
+
+        for i in references:
+            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
+
+    @classmethod
+    def make(cls, *args, **kwargs):
+        expr = cls(*args, **kwargs)
+        active_module_tracer().current_scope().insert(expr)
+        return expr
+
+    def interpret(self, *inputs):
+        return apply(self.opdef, *inputs)
+
+    def __repr__(self):
+        return "{} = {}({})".format(
+            ", ".join(str(i) for i in self.outputs),
+            self.opdef,
+            ", ".join(str(i) for i in self.inputs),
+        )
+
+    @classmethod
+    def apply_module_trace_hook(cls, opdef, *inputs):
+        for i in inputs:
+            node = NodeMixin.get(i, None)
+            if node is None:  # capture as constant
+                NodeMixin.wrap_safe(i, Constant.make(i))
+        apply_node = cls.make(opdef)
+        for i in inputs:
+            apply_node.add_input(NodeMixin.get(i))
+
+        unset_module_tracing()
+        outputs = apply(opdef, *inputs)
+        set_module_tracing()
+
+        apply_node.add_outputs(outputs)
+        for n, v in zip(apply_node.outputs, outputs):
+            NodeMixin.wrap_safe(v, n)
+        return list(outputs)
+
+
+# expr: outputs = self.value
+class Constant(Expr):
+    value = None
+    # TODO: constant cache to reduce the size of dumped model
+    _constant_cache = {}
+
+    def __init__(self, c):
+        # TODO: type check, since not all types should be captured as constant
+        self.value = c
+        self.inputs = []
+        node_cls = NodeMixin.get_wrapped_type(c)
+        self.outputs = [
+            node_cls(self),
+        ]
+
+    @classmethod
+    def make(cls, *args, **kwargs):
+        expr = cls(*args, **kwargs)
+        active_module_tracer().current_scope().insert(expr)
+        return expr.outputs[0]
+
+    def interpret(self, *inputs):
+        if isinstance(self.value, RawTensor):
+            return Const(self.value.numpy())()
+        return (self.value,)
+
+    def __repr__(self):
+        return "{} = Constant({})".format(self.outputs[0], self.value)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if isinstance(self.value, RawTensor):
+            state["value"] = Tensor(self.value)
+        return state
diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py
new file mode 100644
index 00000000..0a0b2807
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ...module import Module
+
+_active_module_tracer = None
+
+
+def active_module_tracer():
+    return _active_module_tracer
+
+
+def set_active_module_tracer(tracer):
+    global _active_module_tracer
+    _active_module_tracer = tracer
+
+
+class module_tracer:
+
+    _opaque_types = set()
+
+    _active_scopes = None
+
+    def __init__(self):
+        self._active_scopes = []
+
+    @classmethod
+    def register_as_builtin(cls, mod):
+        assert issubclass(mod, Module)
+        cls._opaque_types.add(mod)
+        return mod
+
+    @classmethod
+    def is_builtin(cls, mod):
+        return type(mod) in cls._opaque_types
+
+    def push_scope(self, scope):
+        self._active_scopes.append(scope)
+
+    def pop_scope(self):
+        self._active_scopes.pop()
+
+    def current_scope(self):
+        if self._active_scopes:
+            return self._active_scopes[-1]
+        return None
diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py
new file mode 100644
index 00000000..6e6b5a9a
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/node.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Any, Dict, Tuple, Type
+
+import numpy
+
+from ...core._imperative_rt.core2 import Tensor as RawTensor
+from ...module import Module
+from ...tensor import Tensor
+
+
+class Node:
+    """
+    ``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. They are inputs/outputs of Expr (the operations on variables).
+
+    param expr: the Expr which produces the node
+    param name: the name of the node
+    """
+
+    expr = None
+    __total_id = 0
+    _id = None
+    _name = None
+
+    def __init__(self, expr: "Expr", name: str = None):
+        self.expr = expr
+        self._id = Node.__total_id
+        Node.__total_id += 1
+        self._name = name
+
+    def __repr__(self):
+        if self._name is None:
+            return "%{}".format(self._id)
+        else:
+            return "%{}".format(self._name)
+
+
+class ModuleNode(Node):
+    """
+    ``ModuleNode`` represents the Module objects.
+
+    Attributes:
+        module_type: type of the Module corresponding to the ModuleNode
+        graph: the InternalGraph which will be interpreted when the Module's forward method is called
+        attr_type_map: record the type of Module's attributes
+    """
+
+    module_type = Module  # type: Type[Module]
+    graph = None
+    attr_type_map = None  # type: Dict[str, Type[Any]]
+
+    def __repr__(self):
+        if self._name is None:
+            return "%{}({})".format(self._id, self.module_type.__name__)
+        else:
+            return "%{}({})".format(self._name, self.module_type.__name__)
+
+
+class TensorNode(Node):
+    """
+    ``TensorNode`` represents the Tensor objects.
+    """
+
+    shape = None  # type: Tuple[int]
+    dtype = None  # type: numpy.dtype
+
+    def __repr__(self):
+        if self._name is None:
+            return "%{}(Tensor)".format(self._id)
+        else:
+            return "%{}(Tensor)".format(self._name)
+
+
+class NodeMixin:
+    __node = None
+
+    @classmethod
+    def wrap(cls, value, node):
+        if isinstance(value, (NodeMixin, RawTensor)):
+            if isinstance(node, Node):
+                if isinstance(value, RawTensor):
+                    node.dtype = value.dtype
+                    node.shape = (
+                        value._tuple_shape if isinstance(value, Tensor) else value.shape
+                    )
+                setattr(value, "_NodeMixin__node", node)
+            else:
+                assert callable(node)
+                n = node()
+                if isinstance(value, RawTensor):
+                    n.dtype = value.dtype
+                    n.shape = (
+                        value._tuple_shape if isinstance(value, Tensor) else value.shape
+                    )
+                setattr(value, "_NodeMixin__node", n)
+
+    @classmethod
+    def wrap_safe(cls, value, node):
+        assert isinstance(value, (NodeMixin, RawTensor))
+        if isinstance(value, RawTensor):
+            node.dtype = value.dtype
+            node.shape = (
+                value._tuple_shape if isinstance(value, Tensor) else value.shape
+            )
+        setattr(value, "_NodeMixin__node", node)
+
+    @classmethod
+    def get(cls, value, *default):
+        return getattr(value, "_NodeMixin__node", *default)
+
+    @classmethod
+    def get_wrapped_type(cls, value):
+        if isinstance(value, RawTensor):
+            return TensorNode
+        if isinstance(value, (Module, NodeMixin)):
+            return ModuleNode
+        return Node
diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py
new file mode 100644
index 00000000..880c6404
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/traced_module.py
@@ -0,0 +1,295 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import collections
+import copy
+from typing import List, Type
+
+from ... import module as M
+from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
+from ...module import Module
+from ...tensor import Tensor
+from .expr import Apply, Call, Constant, Expr, GetAttr, Input
+from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
+from .node import ModuleNode, Node, NodeMixin, TensorNode
+
+
+class InternalGraph:
+    """
+    ``InternalGraph`` is a graph consisting of ``Node`` and ``Expr`` objects; it is used to represent the execution procedure of Module's forward method.
+
+    Attributes:
+        _exprs: List of Exprs in order of execution
+        _inputs: Input Nodes of InternalGraph
+        _outputs: Output Nodes of InternalGraph
+    """
+
+    _exprs = None  # type: List[Expr]
+    _inputs = None  # type: List[Node]
+    _outputs = None  # type: List[Node]
+
+    def __init__(self):
+        self._exprs = []
+        self._inputs = []
+        self._outputs = []
+
+    def insert(self, expr):
+        self._exprs.append(expr)
+
+    def add_input(self, i):
+        self._inputs.append(i)
+
+    def add_output(self, o):
+        self._outputs.append(o)
+
+    def interpret(self, *inputs):
+        # TODO: support kwargs ?
+        # TODO: skip expressions which are independent and have no side effect
+        node2value = {}
+        for n, v in zip(self._inputs, inputs):
+            node2value[n] = v
+        for expr in self._exprs:
+            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
+            for n, v in zip(expr.outputs, values):
+                node2value[n] = v
+        return list(node2value[i] for i in self._outputs)
+
+    def __repr__(self):
+        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
+            ", ".join(str(i) for i in self._inputs),
+            "\n\t".join(str(i) for i in self._exprs),
+            ", ".join(str(i) for i in self._outputs),
+        )
+
+
+class TracedModuleBuilder(NodeMixin):
+
+    _mod = None  # type: Module
+    _body = None  # type: InternalGraph
+    _is_builtin = None  # type: bool
+
+    __builder_attributes__ = [
+        "_mod",
+        "_body",
+        "_NodeMixin__node",
+        "_is_builtin",
+        "_is_traced",
+        "build",
+    ]
+
+    def __init__(self, mod):
+        super(TracedModuleBuilder, self).__init__()
+        self._mod = mod
+        self._body = InternalGraph()
+        self._is_traced = False
+        self._is_builtin = module_tracer.is_builtin(mod)
+
+    def build(self):
+        if self._is_builtin:
+            node = NodeMixin.get(self)
+            node.module_type = type(self._mod)
+            return self._mod
+        else:
+            node = NodeMixin.get(self)
+            node.graph = self._body
+            node.attr_type_map = {}
+            traced_module = TracedModule(node)
+            for k, v in self.__dict__.items():
+                if k not in TracedModuleBuilder.__builder_attributes__:
+                    if isinstance(v, TracedModuleBuilder):
+                        v = v.build()
+                    setattr(traced_module, k, v)
+                    traced_module.m_node.attr_type_map[k] = type(v)
+            return traced_module
+
+    def __call__(self, *inputs, **kwargs):
+        assert isinstance(self._mod, Module)
+
+        # prepare args and kwargs for inner graph
+        def mark_constant(x):
+            node = NodeMixin.get(x, None)
+            if node is None:  # capture as constant
+                NodeMixin.wrap(x, lambda: Constant.make(x))
+
+        for i in inputs:
+            mark_constant(i)
+        for k, v in kwargs.items():
+            mark_constant(v)
+        callnode = Call.make(NodeMixin.get(self))
+
+        def add_input(x):
+            callnode.add_input(NodeMixin.get(x))
+
+        for i in inputs:
+            add_input(i)
+        for k, v in kwargs.items():
+            add_input(v)
+
+        if self._is_builtin or self._is_traced:
+            unset_module_tracing()
+            outputs = self._mod(*inputs, **kwargs)
+            set_module_tracing()
+            if self._is_builtin:
+                self._body = None
+        else:
+            active_module_tracer().push_scope(self._body)
+            # rebind self to new input node
+            orig_self = NodeMixin.get(self)
+            NodeMixin.wrap_safe(
+                self, Input.make("self", NodeMixin.get_wrapped_type(self))
+            )
+            # prepare args and kwargs for inner graph
+            def wrap(x):
+                wrapped = copy.copy(x)  # FIXME
+                NodeMixin.wrap(
+                    wrapped,
+                    lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
+                )
+                return wrapped
+
+            args = []
+            for i in inputs:
+                args.append(wrap(i))
+            for k, v in kwargs.items():
+                kwargs[k] = wrap(v)
+
+            outputs = type(self._mod).forward(self, *args, **kwargs)
+
+            for i in (
+                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
+            ):
+                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
+
+            NodeMixin.wrap_safe(self, orig_self)
+            self._is_traced = True
+            active_module_tracer().pop_scope()
+
+        # rebind output to outer graph
+        callnode.add_outputs(outputs)
+        for i, node in zip(
+            outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
+            callnode.outputs,
+        ):
+            NodeMixin.wrap_safe(i, node)
+        return outputs
+
+    def __getattr__(self, name):
+        if name not in self._mod.__dict__:
+            attr = getattr(type(self._mod), name).__get__(self, type(self))
+        else:
+            attr = getattr(self._mod, name)
+            if isinstance(attr, Module):
+                attr = TracedModuleBuilder(attr)
+            setattr(self, name, attr)
+            NodeMixin.wrap(
+                attr,
+                lambda: GetAttr.make(
+                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
+                ),
+            )
+        return attr
+
+    def __getattribute__(self, name):
+        if name in TracedModuleBuilder.__builder_attributes__:
+            return super().__getattribute__(name)
+        else:
+            wrapped = super().__getattribute__(name)
+            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
+                assert not self._is_builtin
+                NodeMixin.wrap(
+                    wrapped,
+                    lambda: GetAttr.make(
+                        NodeMixin.get(self),
+                        name,
+                        type=NodeMixin.get_wrapped_type(wrapped),
+                    ),
+                )
+            return wrapped
+
+
+class TracedModule(Module):
+    """
+    ``TracedModule`` is the Module created by tracing a normal Module. It owns a ModuleNode (m_node) and will interpret the m_node.graph when it is called.
+    """
+
+    m_node = None  # type: ModuleNode
+
+    def __init__(self, node):
+        super(TracedModule, self).__init__()
+        self.m_node = node
+
+    def forward(self, *inputs):
+        rst = self.m_node.graph.interpret(self, *inputs)
+        if len(rst) == 1:
+            rst = rst[0]
+        return rst
+
+    def __getstate__(self):
+        d = self.__dict__
+        for k in Module.__dict__:
+            d.pop(k, None)
+        return d
+
+
+def cpp_apply_module_trace(opdef, *args):
+    return Apply.apply_module_trace_hook(opdef, *args)
+
+
+def register_as_builtin(mod_cls: Type[Module]) -> None:
+    """
+    Registers class ``mod_cls`` (subclass of megengine.module.Module) as a builtin module.
+
+    param mod_cls: the Module class which will be treated as a builtin module in tracing
+    """
+    module_tracer.register_as_builtin(mod_cls)
+
+
+def _register_all_builtin_module():
+    from inspect import getmembers, isclass
+
+    for sub_mod in [M, M.qat, M.quantized]:
+        for m in getmembers(sub_mod):
+            if (
+                isclass(m[1])
+                and issubclass(m[1], M.Module)
+                and m[1] is not M.Sequential
+            ):
+                module_tracer.register_as_builtin(m[1])
+
+
+def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
+    """
+    Traces module ``mod`` and returns the corresponding TracedModule.
+
+    param mod: the Module to be converted to a TracedModule
+    param inputs: the positional arguments passed to the forward method of ``mod``
+    param kwargs: the keyword arguments passed to the forward method of ``mod``
+    """
+    assert active_module_tracer() is None
+    try:
+        set_module_tracing()
+        set_active_module_tracer(module_tracer())
+        global_scope = InternalGraph()
+
+        active_module_tracer().push_scope(global_scope)
+
+        builder = TracedModuleBuilder(mod)
+        NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
+
+        for _, i in enumerate(inputs):
+            NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
+        for k, v in kwargs.items():
+            NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
+
+        builder(*inputs, **kwargs)
+        active_module_tracer().pop_scope()
+
+        return builder.build()
+    finally:
+        set_active_module_tracer(None)
+        unset_module_tracing()
--
GitLab
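
Usage note (not part of the patch): the sketch below illustrates how the API added above is expected to be used. Only trace_module, TracedModule, and the m_node.graph attribute come from this commit; MyModule, the input values, and the exact printed graph are illustrative assumptions.

import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import TracedModule, trace_module


class MyModule(M.Module):  # hypothetical user module, not part of the patch
    def __init__(self):
        super().__init__()
        self.linear = M.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.linear(x))


x = mge.tensor([[1.0, 2.0, 3.0, 4.0]])
traced = trace_module(MyModule(), x)   # records forward() into an InternalGraph
assert isinstance(traced, TracedModule)

print(traced.m_node.graph)             # shows the recorded Exprs (GetAttr, Call, Apply, ...)
y = traced(x)                          # re-executes the graph via InternalGraph.interpret

Calling traced(x) goes through TracedModule.forward, which replays the recorded Exprs with InternalGraph.interpret rather than re-running the original Python forward.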
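A second sketch, again under stated assumptions (FusedBlock is a hypothetical class): register_as_builtin marks a user Module class as opaque, so the tracer records a call to it as a single Call expr instead of descending into its forward, mirroring how the patch pre-registers all megengine.module classes via _register_all_builtin_module.

import megengine.module as M
from megengine.experimental.traced_module import register_as_builtin


class FusedBlock(M.Module):  # hypothetical custom module
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x) * 2.0


# After registration, tracing treats FusedBlock as a builtin (opaque) unit:
# its forward is executed but not expanded into the traced graph.
register_as_builtin(FusedBlock)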