From 7b19bc76fb05ac7583e5bc70d9b895ac70093fb4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 8 Sep 2021 18:20:32 +0800 Subject: [PATCH] feat(traced_module): support traced module backward compatible serialization GitOrigin-RevId: aaa9e51c74c11fa7955ae7bbfac476fa9bcf0d7d --- .../megengine/traced_module/__init__.py | 1 + .../python/megengine/traced_module/compat.py | 136 ++++++++ .../python/megengine/traced_module/expr.py | 303 +++++++++++++++-- .../megengine/traced_module/module_tracer.py | 1 - .../python/megengine/traced_module/node.py | 59 +++- .../python/megengine/traced_module/pytree.py | 82 +++-- .../megengine/traced_module/serialization.py | 164 ++++++++- .../megengine/traced_module/traced_module.py | 286 +++++++++++----- .../python/megengine/traced_module/utils.py | 102 +++++- .../test/unit/core/test_serialization.py | 23 -- .../unit/traced_module/test_modification.py | 6 +- .../unit/traced_module/test_qat_module.py | 26 +- .../unit/traced_module/test_serialization.py | 316 +++++++++++++++++- 13 files changed, 1314 insertions(+), 191 deletions(-) create mode 100644 imperative/python/megengine/traced_module/compat.py diff --git a/imperative/python/megengine/traced_module/__init__.py b/imperative/python/megengine/traced_module/__init__.py index 970225bf0..741ec59d2 100644 --- a/imperative/python/megengine/traced_module/__init__.py +++ b/imperative/python/megengine/traced_module/__init__.py @@ -7,6 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ..core._imperative_rt.core2 import set_cpp_apply_module_trace +from . import compat from .traced_module import ( TracedModule, _register_all_builtin_module, diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py new file mode 100644 index 000000000..878d2fc6b --- /dev/null +++ b/imperative/python/megengine/traced_module/compat.py @@ -0,0 +1,136 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import numpy as np + +from .. import tensor +from ..core.ops.builtin import BatchNorm +from .expr import CallMethod, Constant +from .node import TensorNode +from .serialization import ( + register_functional_loader, + register_module_loader, + register_opdef_loader, + register_tensor_method_loader, +) + + +""" +# Expr loaders examples + +from ..core.ops.builtin import Elemwise + +@register_opdef_loader(Elemwise) +def add_opdef_loader(expr): + if expr.opdef_state["mode"] == "ADD": + expr.opdef_state["mode"] == "MUL" + node = expr.inputs[1] + astype_expr = CallMethod(node, "astype") + oup = TensorNode( + astype_expr, + shape=node.shape, + dtype=expr.inputs[0].dtype, + qparams=node.qparams, + ) + + astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) + astype_expr.return_val = (oup,) + expr.inputs[1] = oup + + +@register_functional_loader(("megengine.functional.nn", "conv2d")) +def conv2df_loader(expr): + # expr.func = ("megengine.functional.nn","conv2d") + kwargs = expr.kwargs + orig_weight = expr.named_args["weight"] + + astype_expr = CallMethod(orig_weight, "astype") + oup = TensorNode( + astype_expr, + shape=orig_weight.shape, + dtype=orig_weight.dtype, + qparams=orig_weight.qparams, + ) + + astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype) + astype_expr.return_val = (oup,) + + expr.set_arg("weight", oup) + + +@register_module_loader(("megengine.module.conv", "Conv2d")) +def conv2dm_loader(expr): + module = expr.inputs[0].owner + args = list(expr.args) + orig_inp = args[1] + astype_expr = CallMethod(orig_inp, "astype") + oup = TensorNode( + astype_expr, + shape=orig_inp.shape, + dtype=orig_inp.dtype, + qparams=orig_inp.qparams, + ) + astype_expr.set_args_kwargs(orig_inp, module.weight.dtype) + astype_expr.return_val = (oup,) + args[1] = oup + expr.set_args_kwargs(*args) + + +@register_tensor_method_loader("__add__") +def add_loader(expr): + args = list(expr.args) + if not isinstance(args[1], TensorNode): + args[1] = tensor(args[1]) + node = Constant(args[1], "const").outputs[0] + + astype_expr = CallMethod(node, "astype") + oup = TensorNode( + astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams, + ) + + astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) + astype_expr.return_val = (oup,) + args[1] = oup + expr.set_args_kwargs(*args) +""" + + +@register_module_loader( + ("megengine.module.batchnorm", "BatchNorm1d"), + ("megengine.module.batchnorm", "BatchNorm2d"), + ("megengine.module.batchnorm", "SyncBatchNorm"), +) +def bn2d_module_loader(expr): + # mge 1.6 + if not hasattr(expr, "version"): + module = expr.inputs[0].owner + if not hasattr(module, "param_dim"): + module.param_dim = "dim_1c11" + + +@register_module_loader( + ("megengine.module.conv_bn", "ConvBn2d"), + ("megengine.module.conv_bn", "ConvBnRelu2d"), + ("megengine.module.qat.conv_bn", "ConvBn2d"), + ("megengine.module.qat.conv_bn", "ConvBnRelu2d"), +) +def convbn2d_module_loader(expr): + # mge 1.6 + if not hasattr(expr, "version"): + module = expr.inputs[0].owner + if not hasattr(module.bn, "param_dim"): + module.bn.param_dim = "dim_1c11" + + +@register_opdef_loader(BatchNorm) +def bn_opdef_loader(expr): + # mge 1.6 + if not hasattr(expr, "version"): + output = expr.outputs[-1] + oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) + expr.outputs.insert(4, oup) diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 23fe73d72..fcd49dfe0 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -11,19 +11,28 @@ import collections import copy import inspect import re -from typing import Callable, Dict, List, Optional, Union +import weakref +from importlib import import_module +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union from ..core._imperative_rt import OpDef from ..core._imperative_rt.core2 import Tensor as RawTensor -from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing +from ..core._imperative_rt.core2 import ( + apply, + is_tracing_module, + set_module_tracing, + unset_module_tracing, +) from ..core.ops.builtin import FakeQuant from ..core.ops.special import Const from ..module import Module from ..tensor import Parameter, Tensor +from ..version import __version__ from .module_tracer import active_module_tracer, module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten -from .serialization import get_opdef_state, load_opdef_from_state +from .serialization import _ModuleState +from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args def rstrip(s: str, __chars: str): @@ -112,6 +121,7 @@ class Expr: node.users.append(self) else: assert node is None + assert not isinstance(val, (Module, RawTensor)) assert _is_leaf(val) and _is_const_leaf(val) idx = len(self.inputs) + len(self.const_val) self.const_val.append((idx, val)) @@ -132,14 +142,14 @@ class Expr: current_graph._namespace.auto_naming_for_outputs(self) def unflatten_args(self, inputs): - if self.arg_def is not None: - inputs = list(inputs) - for idx, val in self.const_val: - inputs.insert(idx, val) - args, kwargs = self.arg_def.unflatten(inputs) - return args, kwargs - else: - return inputs, {} + assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format( + type(self).__name__ + ) + inputs = list(inputs) + for idx, val in self.const_val: + inputs.insert(idx, val) + args, kwargs = self.arg_def.unflatten(inputs) + return args, kwargs def replace_inputs(self, repl_dict: Dict[Node, Node]): r"""Replace the input Nodes of this Expr. @@ -165,6 +175,39 @@ class Expr: node.users.remove(self) repl_node.users.append(self) + @property + def _support_set_args_kwargs(self): + return False + + def set_args_kwargs(self, *args, **kwargs): + r""" Set args and kwargs for Expr. + """ + assert ( + self._support_set_args_kwargs + ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__) + args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs) + inputs, arg_def = tree_flatten((args, kwargs)) + orig_inputs = self.inputs + self.inputs = [] + self.const_val = [] + for val in inputs: + if isinstance(val, (TensorNode, ModuleNode)): + self.inputs.append(val) + else: + assert _is_leaf(val) and _is_const_leaf(val) + idx = len(self.inputs) + len(self.const_val) + self.const_val.append((idx, val)) + + for n in orig_inputs: + if n not in self.inputs: + n.users.remove(self) + + for n in self.inputs: + if n not in orig_inputs: + n.users.append(self) + + self.arg_def = arg_def + @property def kwargs(self): r"""Get the keyword arguments of the operation corresponding to this Expr.""" @@ -177,6 +220,61 @@ class Expr: args, _ = self.unflatten_args(self.inputs) return args + def _get_func(self): + # get called function when the expr is interpreted + raise NotImplementedError + + @property + def named_args(self): + func = self._get_func() + return inspect.getcallargs(func, *self.args, **self.kwargs) + + def set_arg(self, name, val): + func = self._get_func() + if name in self.kwargs: + new_kwargs = self.kwargs + new_kwargs[name] = val + self.set_args_kwargs(*self.args, **new_kwargs) + else: + arg_spec = inspect.getfullargspec(func) + if name in arg_spec.args: + ind = arg_spec.args.index(name) + new_args = list(self.args) + new_args[ind] = val + self.set_args_kwargs(*new_args) + elif name == arg_spec.varargs: + assert arg_spec.varargs is not None + assert len(self.args) >= len(arg_spec.args) + val = (val,) if not isinstance(val, Sequence) else val + self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val) + else: + assert ( + arg_spec.varkw is not None + ), "func {} does't have argument named {}".format(func, name) + new_kwargs = self.kwargs + new_kwargs[name] = val + self.set_args_kwargs(*self.args, **new_kwargs) + + @property + def return_val(self): + return self.out_def.unflatten(self.outputs) + + @return_val.setter + def return_val(self, new_outputs): + outputs, out_def = tree_flatten( + new_outputs, is_leaf=lambda x: isinstance(x, Node) + ) + assert all( + isinstance(o, Node) for o in outputs + ), "Return values of expr must be ModuleNode or TensorNode or Container with them" + assert all( + o.expr in (None, self) for o in outputs + ), "Some nodes are produced by other expr, can not be output of expr {}".format( + self + ) + self.outputs = outputs + self.out_def = out_def + @property def top_graph(self): r"""Get the parent graph of this Expr.""" @@ -184,12 +282,6 @@ class Expr: return self._top_graph() return None - def __getstate__(self): - state = self.__dict__.copy() - if "_top_graph" in state: - state.pop("_top_graph") - return state - @classmethod def _get_next_id(cls): return cls.__total_id @@ -199,6 +291,23 @@ class Expr: assert isinstance(id, int) cls.__total_id = id + def __copy__(self): + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType): + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + return result + # expr: None (i.e. fake expression which is used to mark input) class Input(Expr): @@ -229,6 +338,17 @@ class Input(Expr): def __repr__(self): return "%{}:\t{} = Input()".format(self._id, self.outputs[0]) + def __getstate__(self): + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "inputs": self.inputs, + "outputs": self.outputs, + "name": self.name, + } + _check_obj_attr(state) + return state + # expr: outputs = getattr(inputs[0], self.name) class GetAttr(Expr): @@ -276,11 +396,23 @@ class GetAttr(Expr): def __repr__(self): out_type = "Tensor" if isinstance(self.outputs[0], ModuleNode): - out_type = self.outputs[0].module_type.__name__ + m_type = self.outputs[0].module_type + out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1] return '%{}:\t{} = getattr({}, "{}") -> ({})'.format( self._id, self.outputs[0], self.inputs[0], self.name, out_type ) + def __getstate__(self): + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "inputs": self.inputs, + "outputs": self.outputs, + "name": self.name, + } + _check_obj_attr(state) + return state + # expr: outputs = inputs[0].__call__(*inputs[1:]) class CallMethod(Expr): @@ -307,6 +439,7 @@ class CallMethod(Expr): node, ] self.const_val = [] + self.arg_def = tree_flatten(((node,), {}))[1] self.method = method @classmethod @@ -342,6 +475,27 @@ class CallMethod(Expr): outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) return outputs + def _get_func(self): + if isinstance(self.args[0], type): + obj_type = self.args[0] + elif isinstance(self.args[0], ModuleNode): + obj_type = self.args[0].module_type + else: + assert isinstance(self.args[0], TensorNode) + obj_type = Tensor + meth = getattr( + obj_type, "forward" if issubclass(obj_type, Module) else self.method + ) + return meth + + @property + def _support_set_args_kwargs(self): + # only expr call tensor method or builtin module support modify args/kwargs + return ( + isinstance(self.args[0], (TensorNode, type)) + or self.args[0].module_type is not Module + ) + def __repr__(self): args = ", ".join(str(i) for i in self.args[1:]) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) @@ -359,6 +513,21 @@ class CallMethod(Expr): ", ".join([args, kwargs]), ) + def __getstate__(self): + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "inputs": self.inputs, + "const_val": self.const_val, + "method": self.method, + "arg_def": self.arg_def, + "out_def": self.out_def, + "outputs": self.outputs, + "version": __version__, + } + _check_obj_attr(state) + return state + # expr: outputs = apply(self.opdef, *inputs) class Apply(Expr): @@ -394,14 +563,32 @@ class Apply(Expr): ) def __getstate__(self): - state = super().__getstate__() - state["opdef"] = get_opdef_state(state["opdef"]) + opdef_state = self.opdef.__getstate__() + opdef_state["opdef_type"] = type(self.opdef) + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "opdef_state": opdef_state, + "inputs": self.inputs, + "outputs": self.outputs, + "version": __version__, + } + _check_obj_attr(state) return state def __setstate__(self, state): - state["opdef"] = load_opdef_from_state(state["opdef"]) - for k, v in state.items(): - setattr(self, k, v) + # compat with mge 1.6 + if "opdef" in state and "opdef_state" not in state: + opdef_state = state.pop("opdef") + opdef_state["opdef_type"] = opdef_state.pop("type") + state["opdef_state"] = opdef_state + self.__dict__.update(state) + assert isinstance(state["opdef_state"], dict) + opdef_state = state["opdef_state"].copy() + opdef_type = opdef_state.pop("opdef_type") + opdef_obj = opdef_type() + opdef_obj.__setstate__(opdef_state) + setattr(self, "opdef", opdef_obj) @classmethod def apply_module_trace_hook(cls, opdef, *inputs): @@ -458,12 +645,24 @@ class CallFunction(Expr): def interpret(self, *inputs): args, kwargs = self.unflatten_args(inputs) - outputs = self.func(*args, **kwargs) + func = ( + self.func + if not is_tracing_module() + else active_module_tracer().patcher.wrap_fn(self.func) + ) + outputs = func(*args, **kwargs) if outputs is None: return outputs outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) return outputs + def _get_func(self): + return self.func + + @property + def _support_set_args_kwargs(self): + return True + def __repr__(self): args = ", ".join(str(i) for i in self.args) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) @@ -477,6 +676,33 @@ class CallFunction(Expr): ", ".join([args, kwargs]), ) + def __getstate__(self): + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "func": (self.func.__module__, self.func.__qualname__), + "const_val": self.const_val, + "inputs": self.inputs, + "arg_def": self.arg_def, + "out_def": self.out_def, + "outputs": self.outputs, + "version": __version__, + } + _check_obj_attr(state) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + try: + if isinstance(self.func, tuple): + mname, fname = self.func + f = import_module(mname) + for i in fname.split("."): + f = getattr(f, i) + self.func = f + except Exception: + pass + # expr outputs = self.value class Constant(Expr): @@ -496,6 +722,13 @@ class Constant(Expr): assert isinstance(c, (RawTensor, Module)) if isinstance(c, Module): assert module_tracer.is_builtin(c) or c.is_qat + if isinstance(c, RawTensor): + if is_tracing_module(): + unset_module_tracing() + c = Tensor(c) + set_module_tracing() + else: + c = Tensor(c) self.value = c self.name = name self.inputs = [] @@ -530,9 +763,25 @@ class Constant(Expr): ) def __getstate__(self): - state = self.__dict__.copy() - if "_top_graph" in state: - state.pop("_top_graph") + state = { + "_id": self._id, + "_disable_remove": self._disable_remove, + "value": self.value, + "name": self.name, + "inputs": self.inputs, + "outputs": self.outputs, + } + _check_obj_attr(state) if isinstance(self.value, RawTensor): state["value"] = Tensor(self.value) + if isinstance(self.value, Module) and module_tracer.is_builtin(self.value): + _check_builtin_module_attr(self.value) + state["value"] = _ModuleState.get_module_state(self.value) + return state + + def __setstate__(self, state): + for k, v in state.items(): + if isinstance(v, _ModuleState): + state[k] = v.to_module() + self.__dict__.update(state) diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index a0d5dceb9..db2bf0552 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [ "astype", "reshape", "_broadcast", - "transpose", "flatten", "sum", "prod", diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 5aa571ad5..d3c9fccad 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -6,7 +6,9 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import abc +import copy import weakref +from importlib import import_module from typing import Any, Dict, List, Tuple, Type import numpy @@ -14,7 +16,9 @@ import numpy from .. import get_logger from ..core._imperative_rt.core2 import Tensor as RawTensor from ..module import Module +from ..quantization.utils import QParams from ..tensor import Tensor +from .utils import _check_obj_attr logger = get_logger(__name__) @@ -145,6 +149,23 @@ class Node: assert isinstance(id, int) cls.__total_id = id + def __copy__(self): + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType) and k != "actual_node": + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + return result + class ModuleNode(Node): r"""``ModuleNode`` represents the Module objects.""" @@ -157,19 +178,28 @@ class ModuleNode(Node): super().__init__(expr, name, qualname) def __getstate__(self): - return { + state = { "expr": self.expr, "users": self.users, "_id": self._id, "_name": self._name, "_qualname": self._qualname, - "module_type": self.module_type, + "module_type": (self.module_type.__module__, self.module_type.__qualname__), } + _check_obj_attr(state) + return state def __setstate__(self, state): if "_orig_name" in state: state["_qualname"] = state.pop("_orig_name") self.__dict__.update(state) + try: + if isinstance(self.module_type, tuple): + mname, classname = self.module_type + mtype = getattr(import_module(mname), classname) + self.module_type = mtype + except Exception: + pass @property def owner(self): @@ -185,12 +215,26 @@ class TensorNode(Node): _shape = None # type: Tuple[int] _dtype = None # type: numpy.dtype - _qparams = None + _qparams = None # type: QParams _device = None _value = None # type: Tensor + def __init__( + self, + expr: "Expr", + name: str = None, + qualname: str = None, + shape: Tuple[int] = None, + dtype: numpy.dtype = None, + qparams: QParams = None, + ): + super().__init__(expr, name, qualname) + self._shape = shape + self._dtype = shape + self._qparams = qparams + def __getstate__(self): - return { + state = { "expr": self.expr, "users": self.users, "_id": self._id, @@ -201,6 +245,8 @@ class TensorNode(Node): "_name": self._name, "_qualname": self._qualname, } + _check_obj_attr(state) + return state def __setstate__(self, state): if "_orig_name" in state: @@ -276,7 +322,10 @@ class NodeMixin(abc.ABC): assert isinstance(node, TensorNode) assert isinstance(value, RawTensor) if isinstance(value, RawTensor): - node._dtype = value.dtype + try: + node._dtype = value.dtype + except RuntimeError: + node._dtype = None node._shape = ( value._tuple_shape if isinstance(value, Tensor) else value.shape ) diff --git a/imperative/python/megengine/traced_module/pytree.py b/imperative/python/megengine/traced_module/pytree.py index ad7e644b4..0c62dc283 100644 --- a/imperative/python/megengine/traced_module/pytree.py +++ b/imperative/python/megengine/traced_module/pytree.py @@ -7,15 +7,18 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections -from collections import OrderedDict +from collections import OrderedDict, defaultdict +from functools import partial from typing import Callable, NamedTuple import numpy as np +from ..core._imperative_rt import OpDef from ..core._imperative_rt.common import CompNode from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._wrap import Device from ..core.tensor.dtype import QuantDtypeMeta +from ..distributed import Group from ..module import Module from ..quantization.utils import LSQParams, QParams, QuantMode from ..tensor import Parameter, Tensor @@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = { type(Ellipsis), QuantMode, ArgsIndex, + Group, } +USER_REGISTERED_LEAF_TYPE = [] +USER_REGISTERED_CONTAINER_TYPE = [] # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree -SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number] +SUPPORTED_LEAF_CLS = [ + Module, + Node, + NodeMixin, + np.dtype, + np.ndarray, + np.number, + np.bool_, + OpDef, +] NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) def register_supported_type(type, flatten=None, unflatten=None): + tp_info = (type.__module__, type.__qualname__) if flatten and unflatten: - SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) + USER_REGISTERED_CONTAINER_TYPE.append(tp_info) else: - SUPPORTED_LEAF_CLS.append(type) - - -def _dict_flatten(inp): - aux_data = [] - results = [] - for key, value in sorted(inp.items()): - results.append(value) - aux_data.append(key) - return results, tuple(aux_data) + USER_REGISTERED_LEAF_TYPE.append(tp_info) + _register_supported_type(type, flatten, unflatten) -def _dict_unflatten(inps, aux_data): - return dict(zip(aux_data, inps)) +def _register_supported_type(type, flatten=None, unflatten=None): + if flatten and unflatten: + SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) + else: + SUPPORTED_LEAF_CLS.append(type) -def _ordereddict_flatten(inp): +def _dict_flatten(ordered, inp): aux_data = [] results = [] - for key, value in inp.items(): + dict_items = inp.items() if ordered else sorted(inp.items()) + for key, value in dict_items: results.append(value) aux_data.append(key) return results, tuple(aux_data) -def _ordereddict_unflatten(inps, aux_data): - return OrderedDict(zip(aux_data, inps)) +def _dict_unflatten(dict_type, inps, aux_data): + return dict_type(zip(aux_data, inps)) def qparams_flatten(inp): @@ -99,33 +111,41 @@ def qparams_flatten(inp): return results, tuple(aux_data) -def qparams_unflatten(inp, aux_data): - obj = QParams.__new__(QParams) +def qparams_unflatten(qparam_type, inp, aux_data): + obj = qparam_type.__new__(qparam_type) for k, v in zip(aux_data, inp): setattr(obj, k, v) return obj -register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) -register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) -register_supported_type(dict, _dict_flatten, _dict_unflatten) -register_supported_type( - collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten +_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) +_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) +_register_supported_type( + dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict) +) +_register_supported_type( + defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict) ) -register_supported_type( +_register_supported_type( + OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict) +) +_register_supported_type( slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, aux_data: slice(x[0], x[1], x[2]), ) -register_supported_type(QParams, qparams_flatten, qparams_unflatten) +_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams)) +_register_supported_type( + LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams) +) def _is_leaf(obj): - if isinstance(obj, type): - return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE + obj_type = obj if isinstance(obj, type) else type(obj) return ( - isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE + issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS)) + or obj_type in SUPPORTED_LEAF_TYPE ) diff --git a/imperative/python/megengine/traced_module/serialization.py b/imperative/python/megengine/traced_module/serialization.py index 8ce3ed67a..7762a40e9 100644 --- a/imperative/python/megengine/traced_module/serialization.py +++ b/imperative/python/megengine/traced_module/serialization.py @@ -5,30 +5,158 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from typing import Dict +from importlib import import_module +from typing import Dict, Tuple from ..core._imperative_rt import OpDef from ..core.ops import builtin +from ..tensor import Tensor from ..version import __version__ +from .utils import _convert_kwargs_to_args -OPDEF_PARAM_LOADER = {} +OPDEF_LOADER = {} +FUNCTIONAL_LOADER = {} +TENSORMETHOD_LOADER = {} +MODULE_LOADER = {} -def get_opdef_state(obj: OpDef) -> Dict: - state = obj.__getstate__() - state["type"] = type(obj) - state["version"] = __version__ - return state +class _ModuleState: + obj = None + def __init__(self, module: Tuple, state: Dict, version: str): + self.module = module + self.state = state + self.version = version -def load_opdef_from_state(state: Dict) -> OpDef: - assert "type" in state and issubclass(state["type"], OpDef) - assert "version" in state - opdef_type = state.pop("type") - if opdef_type in OPDEF_PARAM_LOADER: - loader = OPDEF_PARAM_LOADER[opdef_type] - state = loader(state) - state.pop("version") - opdef_obj = opdef_type() - opdef_obj.__setstate__(state) - return opdef_obj + @classmethod + def get_module_state(cls, module): + typem = (type(module).__module__, type(module).__qualname__) + state = module.__dict__.copy() + state.pop("_m_dump_modulestate", None) + if hasattr(module, "_m_dump_modulestate"): + assert isinstance(module._m_dump_modulestate, cls) + module._m_dump_modulestate.__init__(typem, state, __version__) + else: + module.__dict__["_m_dump_modulestate"] = _ModuleState( + typem, state, __version__ + ) + + return module._m_dump_modulestate + + def __getstate__(self): + return {"module": self.module, "state": self.state, "version": self.version} + + def to_module(self): + if self.obj is None: + typem = getattr(import_module(self.module[0]), self.module[1]) + m_obj = typem.__new__(typem) + m_obj.__dict__.update(self.state) + self.obj = m_obj + return self.obj + + +def register_opdef_loader(*opdefs): + def callback(loader): + for opdef in opdefs: + assert opdef not in OPDEF_LOADER + OPDEF_LOADER[opdef] = loader + return loader + + return callback + + +def register_functional_loader(*funcs): + def callback(loader): + for func in funcs: + assert func not in FUNCTIONAL_LOADER + FUNCTIONAL_LOADER[func] = loader + return loader + + return callback + + +def register_module_loader(*module_types): + def callback(loader): + for module_type in module_types: + assert module_type not in MODULE_LOADER + MODULE_LOADER[module_type] = loader + return loader + + return callback + + +def register_tensor_method_loader(*methods): + def callback(loader): + for method in methods: + assert method not in TENSORMETHOD_LOADER + TENSORMETHOD_LOADER[method] = loader + return loader + + return callback + + +def _replace_args_kwargs(expr, new_args, new_kwargs): + if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set( + expr.kwargs.keys() + ): + expr.set_args_kwargs(*new_args, **new_kwargs) + + +def load_functional(expr): + func = ( + (expr.func.__module__, expr.func.__qualname__) + if callable(expr.func) + else expr.func + ) + assert isinstance(func, tuple) + if func in FUNCTIONAL_LOADER: + loader = FUNCTIONAL_LOADER[func] + loader(expr) + mname, fname = func + f = import_module(mname) + for i in fname.split("."): + f = getattr(f, i) + expr.func = f + assert callable(expr.func) + if not hasattr(expr, "version") or expr.version != __version__: + args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs) + _replace_args_kwargs(expr, args, kwargs) + + +def load_call_module_expr(expr): + m_type = expr.inputs[0].module_type + if isinstance(m_type, type): + m_type = (m_type.__module__, m_type.__qualname__) + if m_type in MODULE_LOADER: + MODULE_LOADER[m_type](expr) + if isinstance(expr.inputs[0].module_type, tuple): + mname, classname = expr.inputs[0].module_type + expr.inputs[0].module_type = getattr(import_module(mname), classname) + if not hasattr(expr, "version") or expr.version != __version__: + fwd_func = getattr(expr.inputs[0].module_type, "forward") + args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs) + _replace_args_kwargs(expr, args, kwargs) + + +def load_call_tensor_method_expr(expr): + if expr.method in TENSORMETHOD_LOADER: + loader = TENSORMETHOD_LOADER[expr.method] + loader(expr) + if not hasattr(expr, "version") or expr.version != __version__: + tmethod = ( + getattr(expr.args[0], expr.method) + if isinstance(expr.args[0], type) + else getattr(Tensor, expr.method) + ) + args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs) + _replace_args_kwargs(expr, args, kwargs) + + +def load_apply_expr(expr): + opdef_type = type(expr.opdef) + if opdef_type in OPDEF_LOADER: + OPDEF_LOADER[opdef_type](expr) + opdef_state = expr.opdef_state + opdef_obj = opdef_state.pop("opdef_type")() + opdef_obj.__setstate__(opdef_state) + expr.opdef = opdef_obj diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 1000be7be..824150ed4 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -14,6 +14,7 @@ import inspect import keyword import re import weakref +from importlib import import_module from inspect import getcallargs, getmembers, isclass, ismethod from itertools import chain from types import FunctionType @@ -53,6 +54,7 @@ from ..quantization.observer import ( SyncMinMaxObserver, ) from ..tensor import Tensor +from ..version import __version__ from .expr import ( Apply, CallFunction, @@ -80,8 +82,27 @@ from .module_tracer import ( set_active_module_tracer, ) from .node import ModuleNode, Node, NodeMixin, TensorNode -from .pytree import ArgsIndex, tree_flatten -from .utils import replace_container_with_module_container +from .pytree import ( + USER_REGISTERED_CONTAINER_TYPE, + USER_REGISTERED_LEAF_TYPE, + ArgsIndex, + TreeDef, + _register_supported_type, + tree_flatten, +) +from .serialization import ( + _ModuleState, + load_apply_expr, + load_call_module_expr, + load_call_tensor_method_expr, + load_functional, +) +from .utils import ( + _check_builtin_module_attr, + _check_obj_attr, + _convert_kwargs_to_args, + replace_container_with_module_container, +) logger = get_logger(__name__) @@ -341,7 +362,7 @@ class NameSpace: def create_unique_name(self, name: str, node: Any = None) -> str: assert isinstance(name, str), "The name must be a string" - if name in self._used_names and self._used_names[name] is node: + if name in self._used_names and (self._used_names[name] is node): return name name = re.sub("[^0-9a-zA-Z_]+", "_", name) @@ -1067,6 +1088,7 @@ class InternalGraph: if node2value[n][1] == 0: node2value.pop(n) if values is not None: + assert len(values) == len(expr.outputs) for n, v in zip(expr.outputs, values): if ref_count(n) > 0: node2value[n] = [v, ref_count(n)] @@ -1105,13 +1127,27 @@ class InternalGraph: return res def __getstate__(self): - state = self.__dict__.copy() - if "_top_graph" in state: - state.pop("_top_graph") + state = { + "_exprs": self._exprs, + "_inputs": self._inputs, + "_outputs": self._outputs, + "_watch_point": [], + "_end_point": [], + "_namespace": self._namespace, + "_rst": collections.defaultdict(list), + "_name": self._name, + "_qualname": self._qualname, + } + if self._total_ids: + state["_total_ids"] = self._total_ids + + _check_obj_attr(state) + return state def __setstate__(self, state): old_version = False + if "_module_name" in state: old_version = True state["_qualname"] = state.pop("_module_name") @@ -1144,6 +1180,25 @@ class InternalGraph: self._namespace = NameSpace(self._name, self._qualname) self._re_associate_name() + def __copy__(self): + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType): + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + return result + def _get_meth_name(obj, func): tp = obj if isinstance(obj, type) else type(obj) @@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func): def _wrapped_function(orig_func): @functools.wraps(orig_func) def wrapped_fn(*args, **kwargs): - method_func = wrapped_fn - if "method_func" in kwargs: - method_func = kwargs.pop("method_func") + method_func = kwargs.pop("method_func", wrapped_fn) if is_tracing_module(): unset_module_tracing() inputs, tree_def = tree_flatten((args, kwargs)) @@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func): if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) - meth_name, arg_type = None, None - if args: - meth_name = _get_meth_name(args[0], method_func) - arg_type = args[0] if isinstance(args[0], type) else type(args[0]) + args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs) + meth_name = _get_meth_name(args[0], method_func) + arg_type = args[0] if isinstance(args[0], type) else type(args[0]) if meth_name and arg_type and issubclass(arg_type, RawTensor): + inputs, tree_def = tree_flatten((args, kwargs)) self = inputs[0] if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): @@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func): call_node = CallMethod.make(NodeMixin.get(self), meth_name) call_node.add_inputs(inputs[1:]) else: + inputs, tree_def = tree_flatten((args, kwargs)) call_node = CallFunction.make(orig_func) call_node.add_inputs(inputs) @@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin): "_record_wrapped_nodes", "_argdef_graph_map", "_argdef_outdef_map", + "_check_qat_module", "nodes", "__class__", "__dict__", + "_is_top", ] def __init__(self, mod, is_top_module=False): @@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin): qat_module.weight_fake_quant.set_qparams(qparams) def build(self): - if self._is_builtin or isinstance(self._mod, TracedModule): - if module_tracer.is_builtin(self._mod) or isinstance( - self._mod, TracedModule - ): - mod_type = type(self._mod) - else: - assert isinstance(self._mod, (Observer, _FakeQuantize)) - mod_type = ( - Observer if isinstance(self._mod, Observer) else _FakeQuantize - ) + if self._is_builtin: + assert module_tracer.is_builtin(self._mod) + mod_type = type(self._mod) + for node in self.nodes: node.module_type = mod_type return self._mod else: - is_qat = isinstance(self._mod, QATModule) + is_qat = isinstance(self._mod, QATModule) or ( + isinstance(self._mod, TracedModule) and self._mod.is_qat + ) traced_module = TracedModule( self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat ) @@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin): traced_module.with_act = self._mod.with_act traced_module.with_weight = self._mod.with_weight if not hasattr(traced_module, "act_fake_quant"): - traced_module.act_fakequant = None + traced_module.act_fake_quant = None if not hasattr(traced_module, "act_observer"): traced_module.act_observer = None if not hasattr(traced_module, "weight_fake_quant"): - traced_module.weight_fakequant = None + traced_module.weight_fake_quant = None if not hasattr(traced_module, "weight_observer"): traced_module.weight_observer = None set_module_tracing() + if self._is_top: + traced_module._update_ref() + return traced_module def _record_wrapped_nodes(self, node): @@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin): # prepare args and kwargs for inner graph if "method_func" in kwargs: kwargs.pop("method_func") + args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True) def mark_constant(x): node = NodeMixin.get(x, None) @@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin): callnode.arg_def = tree_def - if ( - self._is_builtin - or tree_def in self._argdef_graph_map - or isinstance(self._mod, TracedModule) - ): + if self._is_builtin or tree_def in self._argdef_graph_map: unset_module_tracing() rst = self._mod(*args, **kwargs) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) @@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin): self._body = None elif tree_def in self._argdef_graph_map: self._body = self._argdef_graph_map[tree_def] - else: - self._mod._is_top = False - self._body = self._mod.argdef_graph_map[tree_def] - module_qualname = NodeMixin.get(self).qualname - if module_qualname != self._body.qualname: - src_name, dst_name = self._body.qualname, module_qualname - - def replace_qualname(g): - attr_name = get_suffix_name(src_name, g.qualname) - if attr_name is not None: - g._qualname = ( - ("%s.%s" % (dst_name, attr_name)) - if attr_name - else dst_name - ) - assert get_suffix_name(dst_name, g.qualname) is not None - - for mod in self._mod.modules(): - if not hasattr(mod, "argdef_graph_map"): - continue - for g in mod.argdef_graph_map.values(): - replace_qualname(g) - g._namespace.qualname = g.qualname - for n in g.nodes(False): - replace_qualname(n) else: - self_node = None orig_self = NodeMixin.get(self) parent_graph = active_module_tracer().current_scope() module_qualname = orig_self._qualname @@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin): active_module_tracer().push_scope(self._body) # rebind self to new input node - if self_node: - NodeMixin.wrap_safe(self, self_node) - active_module_tracer().current_scope()._add_input(self_node) - else: - NodeMixin.wrap_safe( - self, - self_node - if self_node - else Input.make( - name="self", - qualname=module_qualname, - type=NodeMixin.get_wrapped_type(self), - ), - ) + NodeMixin.wrap_safe( + self, + Input.make( + name="self", + qualname=module_qualname, + type=NodeMixin.get_wrapped_type(self), + ), + ) origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph @@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin): return x args = [self] - for i, v in enumerate(inputs[1:]): - args.append(wrap(v, idx2key[i + 1])) + orig_traced_inputs = ( + None + if not isinstance(self._mod, TracedModule) + else self._mod.argdef_graph_map[tree_def].inputs + ) + ind = 1 + for v in inputs[1:]: + if isinstance(v, (RawTensor, NodeMixin)): + args_name = ( + orig_traced_inputs[ind]._name + if orig_traced_inputs + else idx2key[ind] + ) + ind += 1 + args.append(wrap(v, args_name)) + else: + args.append(v) args, kwargs = tree_def.unflatten(args) active_module_tracer().patcher.auto_patch( @@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin): attr = getattr(type(self._mod), name).__get__(self, type(self)) else: attr = getattr(self._mod, name) - if ( isinstance(attr, FunctionType) and id(attr) in active_module_tracer().patcher.patched_fn_ids @@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin): wrapped = self.__getattr__(name) if isinstance(wrapped, TracedModuleBuilder): - if not isinstance(mod_attr, (List, Dict)): + if not isinstance(mod_attr, (List, Dict, QATModule)): assert mod_attr is wrapped._mod else: assert mod_attr is wrapped @@ -1977,8 +2011,6 @@ class TracedModule(Module): def graph(self) -> InternalGraph: """Return the ``InternalGraph`` of this ``TracedModule``. """ - if self._is_top: - self._update_ref() assert len(self.argdef_graph_map) == 1 return list(self.argdef_graph_map.values())[0] @@ -2112,7 +2144,7 @@ class TracedModule(Module): if hasattr(obj, "argdef_graph_map") else None ) - if expr_graph is not None: + if expr_graph is not None and not obj.is_qat: exprs = _flatten_subgraph(graph, expr_graph, expr, obj) if parent_graph is not None: @@ -2137,26 +2169,119 @@ class TracedModule(Module): ) new_module.graph._re_associate_name() new_module.graph.compile() + new_module._update_ref() new_module.graph._reset_ids() return new_module def __getstate__(self): - d = self.__dict__ + d = self.__dict__.copy() for k in Module.__dict__: d.pop(k, None) + _check_obj_attr(d) + for k in d: + if module_tracer.is_builtin(d[k]): + assert _check_builtin_module_attr( + d[k] + ), "Module {} can not be serialized. ".format(type(d[k])) + d[k] = _ModuleState.get_module_state(d[k]) + dump_info = { + "version": __version__, + "register_type": USER_REGISTERED_LEAF_TYPE, + "register_container_type": USER_REGISTERED_CONTAINER_TYPE, + "register_mdule": USER_REGISTERED_MODULE, + "register_function": USER_REGISTERED_FUNCTION, + } + d["dump_info"] = dump_info return d + def __setstate__(self, state): + + for k, v in state.items(): + if isinstance(v, _ModuleState): + state[k] = v.to_module() + self.__dict__.update(state) + self._update_ref() + + for _, graph in self.argdef_graph_map.items(): + for expr in graph._exprs: + if isinstance(expr, CallFunction): + load_functional(expr) + if isinstance(expr, CallMethod): + if expr.method == "__call__": + load_call_module_expr(expr) + else: + load_call_tensor_method_expr(expr) + if isinstance(expr, Apply): + load_apply_expr(expr) + + for _, graph in self.argdef_graph_map.items(): + ind = 0 + while ind < len(graph._exprs): + cur_expr = graph._exprs[ind] + has_new_expr = False + for i in cur_expr.inputs: + if i.expr not in graph._exprs and not isinstance(i.expr, Input): + graph._exprs.insert(ind, i.expr) + has_new_expr = True + if not has_new_expr: + ind += 1 + for expr in graph._exprs: + for i in expr.inputs: + if expr.inputs.count(i) != i.users.count(expr): + add_or_del_count = expr.inputs.count(i) - i.users.count(expr) + if add_or_del_count > 0: + i.users.extend([expr] * add_or_del_count) + else: + [i.users.remove(expr) for i in range(-add_or_del_count)] + + for o in expr.outputs: + if o.expr is not expr: + assert o not in o.expr.outputs + o.expr = expr + for node in graph.nodes(False): + # remove users of node which doesn't use node as input + node.users = [e for e in node.users if node in e.inputs] + + for expr in graph._exprs: + graph._namespace.auto_naming_for_outputs(expr) + self._update_ref() + for _, graph in self.argdef_graph_map.items(): + graph._reset_ids() + + def __copy__(self): + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType): + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + result._update_ref() + return result + def cpp_apply_module_trace(opdef, *args): return Apply.apply_module_trace_hook(opdef, *args) +USER_REGISTERED_MODULE = [] +USER_REGISTERED_FUNCTION = [] + + def register_as_builtin(mod_cls: Type[Module]) -> None: r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. Args: mod_cls: the module class which will be treated as builtin module in tracing. """ + USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__)) module_tracer.register_as_builtin(mod_cls) @@ -2181,6 +2306,7 @@ def wrap(func: Callable): Args: func: the function of the global function to insert into the graph when it's called. """ + USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__)) assert callable(func), "func must be a callable" assert hasattr(func, "__code__") fn_name = func.__code__.co_name @@ -2247,6 +2373,8 @@ def trace_module( NodeMixin.wrap_safe( builder, Input.make(name="top", type=ModuleNode, qualname=net_name) ) + args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) + inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index 67bb06b65..9038a5118 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -5,12 +5,17 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections import copy +import inspect from collections.abc import MutableMapping, MutableSequence -from typing import Dict, Iterable, List, Optional, Sequence +from typing import Dict, Iterable, List, Optional, Sequence, Type +from .. import get_logger from ..module import Module +logger = get_logger(__name__) + def replace_container_with_module_container(container): has_module = False @@ -52,6 +57,101 @@ def replace_container_with_module_container(container): return has_module, module_container +def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): + # is_bounded = True when func is a method and provided args don't include 'self' + arg_specs = inspect.getfullargspec(func) + arg_specs_args = arg_specs.args + if is_bounded: + arg_specs_args = arg_specs.args[1:] + new_args = [] + new_kwargs = {} + new_args.extend(args) + if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()): + repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) + raise TypeError( + "{} got multiple values for argument {}".format( + func.__qualname__, ", ".join(repeated_arg_name) + ) + ) + if len(new_args) < len(arg_specs.args): + for ind in range(len(new_args), len(arg_specs_args)): + arg_name = arg_specs_args[ind] + if arg_name in kwargs: + new_args.append(kwargs[arg_name]) + else: + index = ind - len(arg_specs_args) + len(arg_specs.defaults) + assert index < len(arg_specs.defaults) and index >= 0 + new_args.append(arg_specs.defaults[index]) + + for kwarg_name in arg_specs.kwonlyargs: + if kwarg_name in kwargs: + new_kwargs[kwarg_name] = kwargs[kwarg_name] + else: + assert kwarg_name in arg_specs.kwonlydefaults + new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name] + for k, v in kwargs.items(): + if k not in arg_specs.args and k not in arg_specs.kwonlyargs: + if arg_specs.varkw is None: + raise TypeError( + "{} got an unexpected keyword argument {}".format( + func.__qualname__, k + ) + ) + new_kwargs[k] = v + return tuple(new_args), new_kwargs + + +def _check_obj_attr(obj): + # check if all the attributes of a obj is serializable + from .pytree import tree_flatten + from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef + from .expr import Expr + from .traced_module import TracedModule, InternalGraph, NameSpace + + def _check_leaf_type(leaf): + leaf_type = leaf if isinstance(leaf, type) else type(leaf) + traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace] + return ( + issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types)) + or leaf_type in SUPPORTED_LEAF_TYPE + ) + + for _, v in obj.items(): + leafs, _ = tree_flatten(v, is_leaf=lambda _: True) + for leaf in leafs: + assert _check_leaf_type( + leaf + ), "Type {} is not supported by traced module".format( + leaf if isinstance(leaf, type) else type(leaf) + ) + + +def _check_builtin_module_attr(mod): + from .pytree import _is_leaf as _check_leaf_type + from .pytree import tree_flatten + + # check if all the attributes of a builtin module is serializable + is_non_serializable_module = lambda m: isinstance( + m, Module + ) and not _check_builtin_module_attr(m) + for k, v in mod.__dict__.items(): + if k == "_m_dump_modulestate": + continue + if is_non_serializable_module(v): + return False + elif not isinstance(v, Module): + leafs, _ = tree_flatten(v, is_leaf=lambda _: True) + for leaf in leafs: + if not _check_leaf_type(leaf) or is_non_serializable_module(leaf): + logger.warn( + "Type {} is not supported by traced module".format( + leaf if isinstance(leaf, type) else type(leaf) + ) + ) + return False + return True + + class _ModuleList(Module, MutableSequence): r"""A List-like container. diff --git a/imperative/python/test/unit/core/test_serialization.py b/imperative/python/test/unit/core/test_serialization.py index 382efe8d4..15f47eb83 100644 --- a/imperative/python/test/unit/core/test_serialization.py +++ b/imperative/python/test/unit/core/test_serialization.py @@ -15,7 +15,6 @@ import numpy as np import megengine as mge from megengine import Parameter, Tensor from megengine.core.ops import builtin -from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state def test_tensor_serialization(): @@ -88,25 +87,3 @@ def test_compatibility(): test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_2.mge") - - -def test_opdef_serialization(): - with TemporaryFile() as f: - x = builtin.Elemwise(mode="Add") - pickle.dump(get_opdef_state(x), f) - f.seek(0) - load_x = load_opdef_from_state(pickle.load(f)) - assert x == load_x - - with TemporaryFile() as f: - x = builtin.Convolution(stride_h=9, compute_mode="float32") - x.strategy = ( - builtin.Convolution.Strategy.PROFILE - | builtin.Convolution.Strategy.HEURISTIC - | builtin.Convolution.Strategy.REPRODUCIBLE - ) - pickle.dump(get_opdef_state(x), f) - f.seek(0) - load_x = load_opdef_from_state(pickle.load(f)) - assert x.strategy == load_x.strategy - assert x == load_x diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 8a7c80a3d..4797a8daf 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -85,12 +85,12 @@ class NewModule(M.Module): return x -def _check_expr_users(traced_module): +def _check_expr_users(flattened_module): node_user = defaultdict(list) - for expr in traced_module.graph._exprs: + for expr in flattened_module.graph._exprs: for node in expr.inputs: node_user[node].append(expr) - for node in traced_module.graph.nodes(): + for node in flattened_module.graph.nodes(): node.users.sort(key=lambda m: m._id) node_user[node].sort(key=lambda m: m._id) assert node.users == node_user[node] diff --git a/imperative/python/test/unit/traced_module/test_qat_module.py b/imperative/python/test/unit/traced_module/test_qat_module.py index 1bcb74d64..d40111634 100644 --- a/imperative/python/test/unit/traced_module/test_qat_module.py +++ b/imperative/python/test/unit/traced_module/test_qat_module.py @@ -8,6 +8,7 @@ import numpy as np import megengine as mge import megengine.functional as F import megengine.module as M +import megengine.module.qat as QM import megengine.quantization as Q from megengine import Tensor from megengine.module.qat.module import QATModule @@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str): return getattr(self, name) +class MyConvBnRelu2d(M.ConvBnRelu2d): + pass + + +class MyQATConvBnRelu2d(QM.ConvBnRelu2d): + pass + + class Myblcok(M.Module): def __init__(self,): super().__init__() - self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) + self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1) self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) self.add = M.Elemwise("FUSE_ADD_RELU") @@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): def build_observered_net(net: M.Module, observer_cls): - qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) + qat_net = Q.quantize_qat( + net, + qconfig=get_observer_config(observer_cls), + mapping={MyConvBnRelu2d: MyQATConvBnRelu2d}, + ) Q.enable_observer(qat_net) inp = Tensor(np.random.random(size=(5, 3, 32, 32))) qat_net(inp) @@ -134,6 +147,15 @@ def test_trace_qat(): check_qparams(weight_qparams, traced_weight_qparams) if act_qparams: check_qparams(act_qparams, traced_act_qparams) + flatten_traced_net = traced_net.flatten() + conv0_node = flatten_traced_net.graph.get_node_by_name( + "MyModule_block0_conv0" + ).as_unique() + conv0_out_node = flatten_traced_net.graph.get_node_by_name( + "MyModule_block0_conv0_out" + ).as_unique() + assert isinstance(conv0_node.owner, TracedModule) + assert conv0_out_node.expr.inputs[0] is conv0_node _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) diff --git a/imperative/python/test/unit/traced_module/test_serialization.py b/imperative/python/test/unit/traced_module/test_serialization.py index 167dd46de..3e952ee9b 100644 --- a/imperative/python/test/unit/traced_module/test_serialization.py +++ b/imperative/python/test/unit/traced_module/test_serialization.py @@ -6,14 +6,59 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import pickle +from collections import defaultdict +from tempfile import TemporaryFile import numpy as np import megengine.functional as F import megengine.module as M +import megengine.traced_module.serialization as S from megengine import Tensor +from megengine.core._imperative_rt.core2 import apply +from megengine.core.ops import builtin +from megengine.core.ops.builtin import Elemwise from megengine.module import Module from megengine.traced_module import trace_module +from megengine.traced_module.expr import CallMethod, Constant +from megengine.traced_module.node import TensorNode +from megengine.traced_module.serialization import ( + register_functional_loader, + register_module_loader, + register_opdef_loader, + register_tensor_method_loader, +) +from megengine.traced_module.utils import _convert_kwargs_to_args + + +def _check_id(traced_module): + _total_ids = traced_module.graph._total_ids + node_ids = [n._id for n in traced_module.graph.nodes().as_list()] + assert len(set(node_ids)) == len(node_ids) + assert max(node_ids) + 1 == _total_ids[0] + + expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] + assert len(set(expr_ids)) == len(expr_ids) + assert max(expr_ids) + 1 == _total_ids[1] + + +def _check_name(flatened_module): + node_names = [n._name for n in flatened_module.graph.nodes().as_list()] + assert len(set(node_names)) == len(node_names) + + +def _check_expr_users(traced_module): + node_user = defaultdict(list) + for expr in traced_module.graph._exprs: + for node in expr.inputs: + node_user[node].append(expr) + if isinstance(expr, CallMethod) and expr.graph: + _check_expr_users(expr.inputs[0].owner) + + for node in traced_module.graph.nodes(False): + node.users.sort(key=lambda m: m._id) + node_user[node].sort(key=lambda m: m._id) + assert node.users == node_user[node] class MyBlock(Module): @@ -48,5 +93,274 @@ def test_dump_and_load(): traced_module = trace_module(module, x) np.testing.assert_array_equal(expect, traced_module(x)) obj = pickle.dumps(traced_module) - pickle.loads(obj) + new_tm = pickle.loads(obj) + _check_id(new_tm) + _check_expr_users(new_tm) + traced_module.graph._reset_ids() + old_nodes = traced_module.graph.nodes().as_list() + new_nodes = new_tm.graph.nodes().as_list() + old_exprs = traced_module.graph.exprs().as_list() + new_exprs = new_tm.graph.exprs().as_list() + assert len(old_nodes) == len(new_nodes) + for i, j in zip(old_nodes, new_nodes): + assert i._name == j._name + assert i._qualname == j._qualname + assert i._id == j._id + assert len(old_exprs) == len(new_exprs) + for i, j in zip(old_exprs, new_exprs): + assert i._id == j._id + np.testing.assert_array_equal(expect, traced_module(x)) + + +def test_opdef_loader(): + class MyModule1(Module): + def forward(self, x, y): + op = Elemwise("ADD") + return apply(op, x, y)[0] + + m = MyModule1() + x = Tensor(np.ones((20))) + y = Tensor(np.ones((20))) + traced_module = trace_module(m, x, y) + orig_loader_dict = S.OPDEF_LOADER + S.OPDEF_LOADER = {} + + @register_opdef_loader(Elemwise) + def add_opdef_loader(expr): + if expr.opdef_state["mode"] == "ADD": + expr.opdef_state["mode"] = "MUL" + node = expr.inputs[1] + astype_expr = CallMethod(node, "astype") + oup = TensorNode( + astype_expr, + shape=node.shape, + dtype=expr.inputs[0].dtype, + qparams=node.qparams, + ) + astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) + astype_expr.return_val = (oup,) + expr.inputs[1] = oup + + obj = pickle.dumps(traced_module) + new_module = pickle.loads(obj) + _check_id(new_module) + _check_expr_users(new_module) + _check_name(new_module.flatten()) + assert ( + isinstance(new_module.graph._exprs[0], CallMethod) + and new_module.graph._exprs[1].opdef.mode == "MUL" + and len(new_module.graph._exprs) == 2 + ) + result = new_module(x, y) + np.testing.assert_equal(result.numpy(), x.numpy()) + S.OPDEF_LOADER = orig_loader_dict + + +def test_functional_loader(): + class MyModule2(Module): + def forward(self, x, y): + return F.conv2d(x, y) + + m = MyModule2() + x = Tensor(np.random.random((1, 3, 32, 32))) + y = Tensor(np.random.random((3, 3, 3, 3))) + traced_module = trace_module(m, x, y) + orig_loader_dict = S.FUNCTIONAL_LOADER + S.FUNCTIONAL_LOADER = {} + + @register_functional_loader(("megengine.functional.nn", "conv2d")) + def conv2df_loader(expr): + # expr.func = ("megengine.functional.nn","conv2d") + kwargs = expr.kwargs + orig_weight = expr.named_args["weight"] + + astype_expr = CallMethod(orig_weight, "astype") + oup = TensorNode( + astype_expr, + shape=orig_weight.shape, + dtype=orig_weight.dtype, + qparams=orig_weight.qparams, + ) + + astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype) + astype_expr.return_val = (oup,) + + expr.set_arg("weight", oup) + + obj = pickle.dumps(traced_module) + new_module = pickle.loads(obj) + _check_expr_users(new_module) + _check_id(new_module) + result = new_module(x, y) + gt = m(x, y) + assert ( + isinstance(new_module.graph._exprs[0], CallMethod) + and len(new_module.graph._exprs) == 2 + ) + np.testing.assert_equal(result.numpy(), gt.numpy()) + S.FUNCTIONAL_LOADER = orig_loader_dict + + +def test_tensor_method_loader(): + class MyModule3(Module): + def forward(self, x): + return x + 1 + + m = MyModule3() + x = Tensor(np.ones((20))) + traced_module = trace_module(m, x) + orig_loader_dict = S.TENSORMETHOD_LOADER + S.TENSORMETHOD_LOADER = {} + + @register_tensor_method_loader("__add__") + def add_loader(expr): + args = list(expr.args) + if not isinstance(args[1], TensorNode): + args[1] = Tensor(args[1]) + node = Constant(args[1], "const").outputs[0] + + astype_expr = CallMethod(node, "astype") + oup = TensorNode( + astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams, + ) + astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) + astype_expr.return_val = (oup,) + + add_expr = CallMethod(oup, "__add__") + add_expr.set_args_kwargs(oup, oup) + oup1 = TensorNode( + add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams, + ) + add_expr.return_val = oup1 + args[1] = oup1 + expr.set_args_kwargs(*args) + + obj = pickle.dumps(traced_module) + new_module = pickle.loads(obj) + _check_expr_users(new_module) + _check_id(new_module) + result = new_module(x) + gt = m(x) + assert ( + isinstance(new_module.graph._exprs[0], Constant) + and len(new_module.graph._exprs) == 4 + ) + np.testing.assert_equal(result.numpy(), (x + 2).numpy()) + S.TENSORMETHOD_LOADER = orig_loader_dict + + +def test_module_loader(): + class MyModule4(Module): + def __init__(self): + super().__init__() + self.conv = M.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv(x) + + m = MyModule4() + x = Tensor(np.random.random((1, 3, 32, 32))) + traced_module = trace_module(m, x) + orig_loader_dict = S.MODULE_LOADER + S.MODULE_LOADER = {} + + @register_module_loader(("megengine.module.conv", "Conv2d")) + def conv2dm_loader(expr): + module = expr.inputs[0].owner + args = list(expr.args) + orig_inp = args[1] + astype_expr = CallMethod(orig_inp, "astype") + oup = TensorNode( + astype_expr, + shape=orig_inp.shape, + dtype=orig_inp.dtype, + qparams=orig_inp.qparams, + ) + astype_expr.set_args_kwargs(orig_inp, module.weight.dtype) + astype_expr.return_val = (oup,) + args[1] = oup + expr.set_args_kwargs(*args) + + obj = pickle.dumps(traced_module) + new_module = pickle.loads(obj) + result = new_module(x) + gt = m(x) + assert ( + isinstance(new_module.graph._exprs[1], CallMethod) + and len(new_module.graph._exprs) == 3 + ) + np.testing.assert_equal(result.numpy(), gt.numpy()) + S.MODULE_LOADER = orig_loader_dict + + +def test_shared_module(): + class MyModule(M.Module): + def __init__(self): + super().__init__() + self.a = M.Elemwise("ADD") + self.b = self.a + + def forward(self, x, y): + z = self.a(x, y) + z = self.b(z, y) + return z + + x = Tensor(1) + y = Tensor(2) + m = MyModule() + tm = trace_module(m, x, y) + obj = pickle.dumps(tm) + load_tm = pickle.loads(obj) + _check_expr_users(load_tm) + _check_name(load_tm.flatten()) + _check_id(load_tm) + assert load_tm.a is load_tm.b + + +def test_convert_kwargs_to_args(): + def func(a, b, c=4, *, d, e=3, f=4): + pass + + args = (1,) + kwargs = {"b": 1, "d": 6} + new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs) + assert new_args == (1, 1, 4) + assert new_kwargs == {"d": 6, "e": 3, "f": 4} + + args = (1,) + kwargs = {"d": 6} + new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True) + assert new_args == (1, 4) + assert new_kwargs == {"d": 6, "e": 3, "f": 4} + + def func1(a, b, c, d, e, *, f): + pass + + args = () + kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6} + new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs) + assert new_args == (1, 2, 3, 4, 5) + assert new_kwargs == {"f": 6} + + +def test_opdef_serialization(): + with TemporaryFile() as f: + x = builtin.Elemwise(mode="Add") + pickle.dump(x, f) + f.seek(0) + load_x = pickle.load(f) + assert x == load_x + + with TemporaryFile() as f: + x = builtin.Convolution(stride_h=9, compute_mode="float32") + x.strategy = ( + builtin.Convolution.Strategy.PROFILE + | builtin.Convolution.Strategy.HEURISTIC + | builtin.Convolution.Strategy.REPRODUCIBLE + ) + pickle.dump(x, f) + f.seek(0) + load_x = pickle.load(f) + assert x.strategy == load_x.strategy + assert x == load_x -- GitLab