Commit 7b19bc76 authored by: Megvii Engine Team

feat(traced_module): support traced module backward compatible serialization

GitOrigin-RevId: aaa9e51c74c11fa7955ae7bbfac476fa9bcf0d7d
Parent ffbfe59c
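In brief, this commit makes pickled TracedModules portable across MegEngine versions: each Expr's __getstate__ now records the framework version, and loaders registered per opdef, functional, tensor method, or module patch old dumps inside __setstate__. A minimal round-trip sketch (hedged; MyModule is a hypothetical module mirroring the tests below):

import pickle

import numpy as np

import megengine.module as M
from megengine import Tensor
from megengine.traced_module import trace_module


class MyModule(M.Module):
    def forward(self, x):
        return x + 1


m = MyModule()
x = Tensor(np.ones((4,)))
tm = trace_module(m, x)
buf = pickle.dumps(tm)      # __getstate__ records __version__ per Expr
new_tm = pickle.loads(buf)  # __setstate__ dispatches any registered loaders
np.testing.assert_array_equal(tm(x).numpy(), new_tm(x).numpy())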
......@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from .traced_module import (
TracedModule,
_register_all_builtin_module,
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .. import tensor
from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
register_functional_loader,
register_module_loader,
register_opdef_loader,
register_tensor_method_loader,
)
"""
# Expr loaders examples
from ..core.ops.builtin import Elemwise
@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
if expr.opdef_state["mode"] == "ADD":
expr.opdef_state["mode"] == "MUL"
node = expr.inputs[1]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr,
shape=node.shape,
dtype=expr.inputs[0].dtype,
qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
expr.inputs[1] = oup
@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs = expr.kwargs
orig_weight = expr.named_args["weight"]
astype_expr = CallMethod(orig_weight, "astype")
oup = TensorNode(
astype_expr,
shape=orig_weight.shape,
dtype=orig_weight.dtype,
qparams=orig_weight.qparams,
)
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
astype_expr.return_val = (oup,)
expr.set_arg("weight", oup)
@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
module = expr.inputs[0].owner
args = list(expr.args)
orig_inp = args[1]
astype_expr = CallMethod(orig_inp, "astype")
oup = TensorNode(
astype_expr,
shape=orig_inp.shape,
dtype=orig_inp.dtype,
qparams=orig_inp.qparams,
)
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
@register_tensor_method_loader("__add__")
def add_loader(expr):
args = list(expr.args)
if not isinstance(args[1], TensorNode):
args[1] = tensor(args[1])
node = Constant(args[1], "const").outputs[0]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
"""
@register_module_loader(
("megengine.module.batchnorm", "BatchNorm1d"),
("megengine.module.batchnorm", "BatchNorm2d"),
("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module, "param_dim"):
module.param_dim = "dim_1c11"
@register_module_loader(
("megengine.module.conv_bn", "ConvBn2d"),
("megengine.module.conv_bn", "ConvBnRelu2d"),
("megengine.module.qat.conv_bn", "ConvBn2d"),
("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module.bn, "param_dim"):
module.bn.param_dim = "dim_1c11"
@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
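# Assumption: newer MegEngine BatchNorm defines one more output than 1.6
# (inserted at index 4 below); pad old dumps with a placeholder TensorNode
# so the new interpreter finds the expected number of outputs.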
output = expr.outputs[-1]
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
expr.outputs.insert(4, oup)
......@@ -11,19 +11,28 @@ import collections
import copy
import inspect
import re
from typing import Callable, Dict, List, Optional, Union
import weakref
from importlib import import_module
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ..core._imperative_rt.core2 import (
apply,
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from ..version import __version__
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import get_opdef_state, load_opdef_from_state
from .serialization import _ModuleState
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
def rstrip(s: str, __chars: str):
......@@ -112,6 +121,7 @@ class Expr:
node.users.append(self)
else:
assert node is None
assert not isinstance(val, (Module, RawTensor))
assert _is_leaf(val) and _is_const_leaf(val)
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
......@@ -132,14 +142,14 @@ class Expr:
current_graph._namespace.auto_naming_for_outputs(self)
def unflatten_args(self, inputs):
if self.arg_def is not None:
inputs = list(inputs)
for idx, val in self.const_val:
inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
return args, kwargs
else:
return inputs, {}
assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
type(self).__name__
)
inputs = list(inputs)
for idx, val in self.const_val:
inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
return args, kwargs
def replace_inputs(self, repl_dict: Dict[Node, Node]):
r"""Replace the input Nodes of this Expr.
......@@ -165,6 +175,39 @@ class Expr:
node.users.remove(self)
repl_node.users.append(self)
@property
def _support_set_args_kwargs(self):
return False
def set_args_kwargs(self, *args, **kwargs):
r""" Set args and kwargs for Expr.
"""
assert (
self._support_set_args_kwargs
), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
inputs, arg_def = tree_flatten((args, kwargs))
orig_inputs = self.inputs
self.inputs = []
self.const_val = []
for val in inputs:
if isinstance(val, (TensorNode, ModuleNode)):
self.inputs.append(val)
else:
assert _is_leaf(val) and _is_const_leaf(val)
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
for n in orig_inputs:
if n not in self.inputs:
n.users.remove(self)
for n in self.inputs:
if n not in orig_inputs:
n.users.append(self)
self.arg_def = arg_def
@property
def kwargs(self):
r"""Get the keyword arguments of the operation corresponding to this Expr."""
......@@ -177,6 +220,61 @@ class Expr:
args, _ = self.unflatten_args(self.inputs)
return args
def _get_func(self):
# return the function that is called when this expr is interpreted
raise NotImplementedError
@property
def named_args(self):
func = self._get_func()
return inspect.getcallargs(func, *self.args, **self.kwargs)
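# e.g. for a conv2d CallFunction, named_args might look like
# {"inp": <TensorNode>, "weight": <TensorNode>, "bias": None, "stride": 1, ...}
# (illustrative values; actual keys follow the wrapped function's signature)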
def set_arg(self, name, val):
func = self._get_func()
if name in self.kwargs:
new_kwargs = self.kwargs
new_kwargs[name] = val
self.set_args_kwargs(*self.args, **new_kwargs)
else:
arg_spec = inspect.getfullargspec(func)
if name in arg_spec.args:
ind = arg_spec.args.index(name)
new_args = list(self.args)
new_args[ind] = val
self.set_args_kwargs(*new_args)
elif name == arg_spec.varargs:
assert arg_spec.varargs is not None
assert len(self.args) >= len(arg_spec.args)
val = (val,) if not isinstance(val, Sequence) else val
self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
else:
assert (
arg_spec.varkw is not None
), "func {} does't have argument named {}".format(func, name)
new_kwargs = self.kwargs
new_kwargs[name] = val
self.set_args_kwargs(*self.args, **new_kwargs)
@property
def return_val(self):
return self.out_def.unflatten(self.outputs)
@return_val.setter
def return_val(self, new_outputs):
outputs, out_def = tree_flatten(
new_outputs, is_leaf=lambda x: isinstance(x, Node)
)
assert all(
isinstance(o, Node) for o in outputs
), "Return values of expr must be ModuleNode or TensorNode or Container with them"
assert all(
o.expr in (None, self) for o in outputs
), "Some nodes are produced by other expr, can not be output of expr {}".format(
self
)
self.outputs = outputs
self.out_def = out_def
@property
def top_graph(self):
r"""Get the parent graph of this Expr."""
......@@ -184,12 +282,6 @@ class Expr:
return self._top_graph()
return None
def __getstate__(self):
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
return state
@classmethod
def _get_next_id(cls):
return cls.__total_id
......@@ -199,6 +291,23 @@ class Expr:
assert isinstance(id, int)
cls.__total_id = id
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
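# weakref attributes (e.g. _top_graph) are skipped: they cannot be pickled
# or deep-copied, and are re-established when the graph is rebuilt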
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
# expr: None (i.e. a fake expression used to mark graph inputs)
class Input(Expr):
......@@ -229,6 +338,17 @@ class Input(Expr):
def __repr__(self):
return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
def __getstate__(self):
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"inputs": self.inputs,
"outputs": self.outputs,
"name": self.name,
}
_check_obj_attr(state)
return state
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
......@@ -276,11 +396,23 @@ class GetAttr(Expr):
def __repr__(self):
out_type = "Tensor"
if isinstance(self.outputs[0], ModuleNode):
out_type = self.outputs[0].module_type.__name__
m_type = self.outputs[0].module_type
out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
self._id, self.outputs[0], self.inputs[0], self.name, out_type
)
def __getstate__(self):
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"inputs": self.inputs,
"outputs": self.outputs,
"name": self.name,
}
_check_obj_attr(state)
return state
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
......@@ -307,6 +439,7 @@ class CallMethod(Expr):
node,
]
self.const_val = []
self.arg_def = tree_flatten(((node,), {}))[1]
self.method = method
@classmethod
......@@ -342,6 +475,27 @@ class CallMethod(Expr):
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs
def _get_func(self):
if isinstance(self.args[0], type):
obj_type = self.args[0]
elif isinstance(self.args[0], ModuleNode):
obj_type = self.args[0].module_type
else:
assert isinstance(self.args[0], TensorNode)
obj_type = Tensor
meth = getattr(
obj_type, "forward" if issubclass(obj_type, Module) else self.method
)
return meth
@property
def _support_set_args_kwargs(self):
# only exprs that call a tensor method or a builtin module support modifying args/kwargs
return (
isinstance(self.args[0], (TensorNode, type))
or self.args[0].module_type is not Module
)
def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
......@@ -359,6 +513,21 @@ class CallMethod(Expr):
", ".join([args, kwargs]),
)
def __getstate__(self):
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"inputs": self.inputs,
"const_val": self.const_val,
"method": self.method,
"arg_def": self.arg_def,
"out_def": self.out_def,
"outputs": self.outputs,
"version": __version__,
}
_check_obj_attr(state)
return state
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
......@@ -394,14 +563,32 @@ class Apply(Expr):
)
def __getstate__(self):
state = super().__getstate__()
state["opdef"] = get_opdef_state(state["opdef"])
opdef_state = self.opdef.__getstate__()
opdef_state["opdef_type"] = type(self.opdef)
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"opdef_state": opdef_state,
"inputs": self.inputs,
"outputs": self.outputs,
"version": __version__,
}
_check_obj_attr(state)
return state
def __setstate__(self, state):
state["opdef"] = load_opdef_from_state(state["opdef"])
for k, v in state.items():
setattr(self, k, v)
# compat with mge 1.6
if "opdef" in state and "opdef_state" not in state:
opdef_state = state.pop("opdef")
opdef_state["opdef_type"] = opdef_state.pop("type")
state["opdef_state"] = opdef_state
self.__dict__.update(state)
assert isinstance(state["opdef_state"], dict)
opdef_state = state["opdef_state"].copy()
opdef_type = opdef_state.pop("opdef_type")
opdef_obj = opdef_type()
opdef_obj.__setstate__(opdef_state)
setattr(self, "opdef", opdef_obj)
@classmethod
def apply_module_trace_hook(cls, opdef, *inputs):
......@@ -458,12 +645,24 @@ class CallFunction(Expr):
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
func = (
self.func
if not is_tracing_module()
else active_module_tracer().patcher.wrap_fn(self.func)
)
outputs = func(*args, **kwargs)
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs
def _get_func(self):
return self.func
@property
def _support_set_args_kwargs(self):
return True
def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
......@@ -477,6 +676,33 @@ class CallFunction(Expr):
", ".join([args, kwargs]),
)
def __getstate__(self):
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"func": (self.func.__module__, self.func.__qualname__),
"const_val": self.const_val,
"inputs": self.inputs,
"arg_def": self.arg_def,
"out_def": self.out_def,
"outputs": self.outputs,
"version": __version__,
}
_check_obj_attr(state)
return state
def __setstate__(self, state):
self.__dict__.update(state)
try:
if isinstance(self.func, tuple):
mname, fname = self.func
f = import_module(mname)
for i in fname.split("."):
f = getattr(f, i)
self.func = f
except Exception:
pass
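# if the import fails, func stays a (module, qualname) tuple;
# load_functional retries the lookup with any registered loaders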
# expr: outputs = self.value
class Constant(Expr):
......@@ -496,6 +722,13 @@ class Constant(Expr):
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c) or c.is_qat
if isinstance(c, RawTensor):
if is_tracing_module():
unset_module_tracing()
c = Tensor(c)
set_module_tracing()
else:
c = Tensor(c)
self.value = c
self.name = name
self.inputs = []
......@@ -530,9 +763,25 @@ class Constant(Expr):
)
def __getstate__(self):
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"value": self.value,
"name": self.name,
"inputs": self.inputs,
"outputs": self.outputs,
}
_check_obj_attr(state)
if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value)
if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
_check_builtin_module_attr(self.value)
state["value"] = _ModuleState.get_module_state(self.value)
return state
def __setstate__(self, state):
for k, v in state.items():
if isinstance(v, _ModuleState):
state[k] = v.to_module()
self.__dict__.update(state)
......@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
"astype",
"reshape",
"_broadcast",
"transpose",
"flatten",
"sum",
"prod",
......
......@@ -6,7 +6,9 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import copy
import weakref
from importlib import import_module
from typing import Any, Dict, List, Tuple, Type
import numpy
......@@ -14,7 +16,9 @@ import numpy
from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor
from .utils import _check_obj_attr
logger = get_logger(__name__)
......@@ -145,6 +149,23 @@ class Node:
assert isinstance(id, int)
cls.__total_id = id
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects."""
......@@ -157,19 +178,28 @@ class ModuleNode(Node):
super().__init__(expr, name, qualname)
def __getstate__(self):
return {
state = {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"_name": self._name,
"_qualname": self._qualname,
"module_type": self.module_type,
"module_type": (self.module_type.__module__, self.module_type.__qualname__),
}
_check_obj_attr(state)
return state
def __setstate__(self, state):
if "_orig_name" in state:
state["_qualname"] = state.pop("_orig_name")
self.__dict__.update(state)
try:
if isinstance(self.module_type, tuple):
mname, classname = self.module_type
mtype = getattr(import_module(mname), classname)
self.module_type = mtype
except Exception:
pass
@property
def owner(self):
......@@ -185,12 +215,26 @@ class TensorNode(Node):
_shape = None # type: Tuple[int]
_dtype = None # type: numpy.dtype
_qparams = None
_qparams = None # type: QParams
_device = None
_value = None # type: Tensor
def __init__(
self,
expr: "Expr",
name: str = None,
qualname: str = None,
shape: Tuple[int] = None,
dtype: numpy.dtype = None,
qparams: QParams = None,
):
super().__init__(expr, name, qualname)
self._shape = shape
self._dtype = dtype
self._qparams = qparams
def __getstate__(self):
return {
state = {
"expr": self.expr,
"users": self.users,
"_id": self._id,
......@@ -201,6 +245,8 @@ class TensorNode(Node):
"_name": self._name,
"_qualname": self._qualname,
}
_check_obj_attr(state)
return state
def __setstate__(self, state):
if "_orig_name" in state:
......@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
assert isinstance(node, TensorNode)
assert isinstance(value, RawTensor)
if isinstance(value, RawTensor):
node._dtype = value.dtype
try:
node._dtype = value.dtype
except RuntimeError:
node._dtype = None
node._shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
......
......@@ -7,15 +7,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from functools import partial
from typing import Callable, NamedTuple
import numpy as np
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.common import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._wrap import Device
from ..core.tensor.dtype import QuantDtypeMeta
from ..distributed import Group
from ..module import Module
from ..quantization.utils import LSQParams, QParams, QuantMode
from ..tensor import Parameter, Tensor
......@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
type(Ellipsis),
QuantMode,
ArgsIndex,
Group,
}
USER_REGISTERED_LEAF_TYPE = []
USER_REGISTERED_CONTAINER_TYPE = []
# if isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) holds, obj is treated as a leaf node of the pytree
SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]
SUPPORTED_LEAF_CLS = [
Module,
Node,
NodeMixin,
np.dtype,
np.ndarray,
np.number,
np.bool_,
OpDef,
]
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
def register_supported_type(type, flatten=None, unflatten=None):
tp_info = (type.__module__, type.__qualname__)
if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
else:
SUPPORTED_LEAF_CLS.append(type)
def _dict_flatten(inp):
aux_data = []
results = []
for key, value in sorted(inp.items()):
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)
USER_REGISTERED_LEAF_TYPE.append(tp_info)
_register_supported_type(type, flatten, unflatten)
def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps))
def _register_supported_type(type, flatten=None, unflatten=None):
if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
else:
SUPPORTED_LEAF_CLS.append(type)
def _ordereddict_flatten(inp):
def _dict_flatten(ordered, inp):
aux_data = []
results = []
for key, value in inp.items():
dict_items = inp.items() if ordered else sorted(inp.items())
for key, value in dict_items:
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)
def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))
def _dict_unflatten(dict_type, inps, aux_data):
return dict_type(zip(aux_data, inps))
def qparams_flatten(inp):
......@@ -99,33 +111,41 @@ def qparams_flatten(inp):
return results, tuple(aux_data)
def qparams_unflatten(inp, aux_data):
obj = QParams.__new__(QParams)
def qparams_unflatten(qparam_type, inp, aux_data):
obj = qparam_type.__new__(qparam_type)
for k, v in zip(aux_data, inp):
setattr(obj, k, v)
return obj
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
_register_supported_type(
dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
)
_register_supported_type(
defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
)
register_supported_type(
_register_supported_type(
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
)
_register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),
lambda x, aux_data: slice(x[0], x[1], x[2]),
)
register_supported_type(QParams, qparams_flatten, qparams_unflatten)
_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
_register_supported_type(
LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
)
def _is_leaf(obj):
if isinstance(obj, type):
return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
obj_type = obj if isinstance(obj, type) else type(obj)
return (
isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE
issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS))
or obj_type in SUPPORTED_LEAF_TYPE
)
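# _is_leaf now accepts classes as well as instances, e.g. (a hedged sketch):
#   _is_leaf(np.float32(1.0))  # True: instance of an np.number subclass
#   _is_leaf(np.ndarray)       # True: the class itself passes issubclass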
......
......@@ -5,30 +5,158 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict
from importlib import import_module
from typing import Dict, Tuple
from ..core._imperative_rt import OpDef
from ..core.ops import builtin
from ..tensor import Tensor
from ..version import __version__
from .utils import _convert_kwargs_to_args
OPDEF_PARAM_LOADER = {}
OPDEF_LOADER = {}
FUNCTIONAL_LOADER = {}
TENSORMETHOD_LOADER = {}
MODULE_LOADER = {}
def get_opdef_state(obj: OpDef) -> Dict:
state = obj.__getstate__()
state["type"] = type(obj)
state["version"] = __version__
return state
class _ModuleState:
obj = None
def __init__(self, module: Tuple, state: Dict, version: str):
self.module = module
self.state = state
self.version = version
def load_opdef_from_state(state: Dict) -> OpDef:
assert "type" in state and issubclass(state["type"], OpDef)
assert "version" in state
opdef_type = state.pop("type")
if opdef_type in OPDEF_PARAM_LOADER:
loader = OPDEF_PARAM_LOADER[opdef_type]
state = loader(state)
state.pop("version")
opdef_obj = opdef_type()
opdef_obj.__setstate__(state)
return opdef_obj
@classmethod
def get_module_state(cls, module):
typem = (type(module).__module__, type(module).__qualname__)
state = module.__dict__.copy()
state.pop("_m_dump_modulestate", None)
if hasattr(module, "_m_dump_modulestate"):
assert isinstance(module._m_dump_modulestate, cls)
module._m_dump_modulestate.__init__(typem, state, __version__)
else:
module.__dict__["_m_dump_modulestate"] = _ModuleState(
typem, state, __version__
)
return module._m_dump_modulestate
def __getstate__(self):
return {"module": self.module, "state": self.state, "version": self.version}
def to_module(self):
if self.obj is None:
typem = getattr(import_module(self.module[0]), self.module[1])
m_obj = typem.__new__(typem)
m_obj.__dict__.update(self.state)
self.obj = m_obj
return self.obj
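# A hedged sketch (not part of this diff) of the round trip _ModuleState
# provides for builtin modules; M.Conv2d is used purely for illustration:
#
#   conv = M.Conv2d(3, 3, 3)
#   state = _ModuleState.get_module_state(conv)  # snapshot (module path, __dict__, version)
#   rebuilt = state.to_module()                  # rebuilt via __new__; __init__ is not re-run
#   assert type(rebuilt) is M.Conv2d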
def register_opdef_loader(*opdefs):
def callback(loader):
for opdef in opdefs:
assert opdef not in OPDEF_LOADER
OPDEF_LOADER[opdef] = loader
return loader
return callback
def register_functional_loader(*funcs):
def callback(loader):
for func in funcs:
assert func not in FUNCTIONAL_LOADER
FUNCTIONAL_LOADER[func] = loader
return loader
return callback
def register_module_loader(*module_types):
def callback(loader):
for module_type in module_types:
assert module_type not in MODULE_LOADER
MODULE_LOADER[module_type] = loader
return loader
return callback
def register_tensor_method_loader(*methods):
def callback(loader):
for method in methods:
assert method not in TENSORMETHOD_LOADER
TENSORMETHOD_LOADER[method] = loader
return loader
return callback
def _replace_args_kwargs(expr, new_args, new_kwargs):
if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
expr.kwargs.keys()
):
expr.set_args_kwargs(*new_args, **new_kwargs)
def load_functional(expr):
func = (
(expr.func.__module__, expr.func.__qualname__)
if callable(expr.func)
else expr.func
)
assert isinstance(func, tuple)
if func in FUNCTIONAL_LOADER:
loader = FUNCTIONAL_LOADER[func]
loader(expr)
mname, fname = func
f = import_module(mname)
for i in fname.split("."):
f = getattr(f, i)
expr.func = f
assert callable(expr.func)
if not hasattr(expr, "version") or expr.version != __version__:
args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
_replace_args_kwargs(expr, args, kwargs)
def load_call_module_expr(expr):
m_type = expr.inputs[0].module_type
if isinstance(m_type, type):
m_type = (m_type.__module__, m_type.__qualname__)
if m_type in MODULE_LOADER:
MODULE_LOADER[m_type](expr)
if isinstance(expr.inputs[0].module_type, tuple):
mname, classname = expr.inputs[0].module_type
expr.inputs[0].module_type = getattr(import_module(mname), classname)
if not hasattr(expr, "version") or expr.version != __version__:
fwd_func = getattr(expr.inputs[0].module_type, "forward")
args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
_replace_args_kwargs(expr, args, kwargs)
def load_call_tensor_method_expr(expr):
if expr.method in TENSORMETHOD_LOADER:
loader = TENSORMETHOD_LOADER[expr.method]
loader(expr)
if not hasattr(expr, "version") or expr.version != __version__:
tmethod = (
getattr(expr.args[0], expr.method)
if isinstance(expr.args[0], type)
else getattr(Tensor, expr.method)
)
args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
_replace_args_kwargs(expr, args, kwargs)
def load_apply_expr(expr):
opdef_type = type(expr.opdef)
if opdef_type in OPDEF_LOADER:
OPDEF_LOADER[opdef_type](expr)
opdef_state = expr.opdef_state
opdef_obj = opdef_state.pop("opdef_type")()
opdef_obj.__setstate__(opdef_state)
expr.opdef = opdef_obj
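# Dispatch summary (applied in TracedModule.__setstate__, see below):
#   CallFunction            -> load_functional
#   CallMethod("__call__")  -> load_call_module_expr
#   CallMethod(<method>)    -> load_call_tensor_method_expr
#   Apply                   -> load_apply_expr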
......@@ -14,6 +14,7 @@ import inspect
import keyword
import re
import weakref
from importlib import import_module
from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain
from types import FunctionType
......@@ -53,6 +54,7 @@ from ..quantization.observer import (
SyncMinMaxObserver,
)
from ..tensor import Tensor
from ..version import __version__
from .expr import (
Apply,
CallFunction,
......@@ -80,8 +82,27 @@ from .module_tracer import (
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, tree_flatten
from .utils import replace_container_with_module_container
from .pytree import (
USER_REGISTERED_CONTAINER_TYPE,
USER_REGISTERED_LEAF_TYPE,
ArgsIndex,
TreeDef,
_register_supported_type,
tree_flatten,
)
from .serialization import (
_ModuleState,
load_apply_expr,
load_call_module_expr,
load_call_tensor_method_expr,
load_functional,
)
from .utils import (
_check_builtin_module_attr,
_check_obj_attr,
_convert_kwargs_to_args,
replace_container_with_module_container,
)
logger = get_logger(__name__)
......@@ -341,7 +362,7 @@ class NameSpace:
def create_unique_name(self, name: str, node: Any = None) -> str:
assert isinstance(name, str), "The name must be a string"
if name in self._used_names and self._used_names[name] is node:
if name in self._used_names and (self._used_names[name] is node):
return name
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
......@@ -1067,6 +1088,7 @@ class InternalGraph:
if node2value[n][1] == 0:
node2value.pop(n)
if values is not None:
assert len(values) == len(expr.outputs)
for n, v in zip(expr.outputs, values):
if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)]
......@@ -1105,13 +1127,27 @@ class InternalGraph:
return res
def __getstate__(self):
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
state = {
"_exprs": self._exprs,
"_inputs": self._inputs,
"_outputs": self._outputs,
"_watch_point": [],
"_end_point": [],
"_namespace": self._namespace,
"_rst": collections.defaultdict(list),
"_name": self._name,
"_qualname": self._qualname,
}
if self._total_ids:
state["_total_ids"] = self._total_ids
_check_obj_attr(state)
return state
def __setstate__(self, state):
old_version = False
if "_module_name" in state:
old_version = True
state["_qualname"] = state.pop("_module_name")
......@@ -1144,6 +1180,25 @@ class InternalGraph:
self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name()
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
def _get_meth_name(obj, func):
tp = obj if isinstance(obj, type) else type(obj)
......@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*args, **kwargs):
method_func = wrapped_fn
if "method_func" in kwargs:
method_func = kwargs.pop("method_func")
method_func = kwargs.pop("method_func", wrapped_fn)
if is_tracing_module():
unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs))
......@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name, arg_type = None, None
if args:
meth_name = _get_meth_name(args[0], method_func)
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs)
meth_name = _get_meth_name(args[0], method_func)
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
if meth_name and arg_type and issubclass(arg_type, RawTensor):
inputs, tree_def = tree_flatten((args, kwargs))
self = inputs[0]
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
......@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
call_node.add_inputs(inputs[1:])
else:
inputs, tree_def = tree_flatten((args, kwargs))
call_node = CallFunction.make(orig_func)
call_node.add_inputs(inputs)
......@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
"_record_wrapped_nodes",
"_argdef_graph_map",
"_argdef_outdef_map",
"_check_qat_module",
"nodes",
"__class__",
"__dict__",
"_is_top",
]
def __init__(self, mod, is_top_module=False):
......@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
qat_module.weight_fake_quant.set_qparams(qparams)
def build(self):
if self._is_builtin or isinstance(self._mod, TracedModule):
if module_tracer.is_builtin(self._mod) or isinstance(
self._mod, TracedModule
):
mod_type = type(self._mod)
else:
assert isinstance(self._mod, (Observer, _FakeQuantize))
mod_type = (
Observer if isinstance(self._mod, Observer) else _FakeQuantize
)
if self._is_builtin:
assert module_tracer.is_builtin(self._mod)
mod_type = type(self._mod)
for node in self.nodes:
node.module_type = mod_type
return self._mod
else:
is_qat = isinstance(self._mod, QATModule)
is_qat = isinstance(self._mod, QATModule) or (
isinstance(self._mod, TracedModule) and self._mod.is_qat
)
traced_module = TracedModule(
self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
)
......@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fakequant = None
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fakequant = None
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None
set_module_tracing()
if self._is_top:
traced_module._update_ref()
return traced_module
def _record_wrapped_nodes(self, node):
......@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
# prepare args and kwargs for inner graph
if "method_func" in kwargs:
kwargs.pop("method_func")
args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)
def mark_constant(x):
node = NodeMixin.get(x, None)
......@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def
if (
self._is_builtin
or tree_def in self._argdef_graph_map
or isinstance(self._mod, TracedModule)
):
if self._is_builtin or tree_def in self._argdef_graph_map:
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
......@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
self._body = None
elif tree_def in self._argdef_graph_map:
self._body = self._argdef_graph_map[tree_def]
else:
self._mod._is_top = False
self._body = self._mod.argdef_graph_map[tree_def]
module_qualname = NodeMixin.get(self).qualname
if module_qualname != self._body.qualname:
src_name, dst_name = self._body.qualname, module_qualname
def replace_qualname(g):
attr_name = get_suffix_name(src_name, g.qualname)
if attr_name is not None:
g._qualname = (
("%s.%s" % (dst_name, attr_name))
if attr_name
else dst_name
)
assert get_suffix_name(dst_name, g.qualname) is not None
for mod in self._mod.modules():
if not hasattr(mod, "argdef_graph_map"):
continue
for g in mod.argdef_graph_map.values():
replace_qualname(g)
g._namespace.qualname = g.qualname
for n in g.nodes(False):
replace_qualname(n)
else:
self_node = None
orig_self = NodeMixin.get(self)
parent_graph = active_module_tracer().current_scope()
module_qualname = orig_self._qualname
......@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().push_scope(self._body)
# rebind self to new input node
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope()._add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make(
name="self",
qualname=module_qualname,
type=NodeMixin.get_wrapped_type(self),
),
)
NodeMixin.wrap_safe(
self,
Input.make(
name="self",
qualname=module_qualname,
type=NodeMixin.get_wrapped_type(self),
),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
......@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
return x
args = [self]
for i, v in enumerate(inputs[1:]):
args.append(wrap(v, idx2key[i + 1]))
orig_traced_inputs = (
None
if not isinstance(self._mod, TracedModule)
else self._mod.argdef_graph_map[tree_def].inputs
)
ind = 1
for v in inputs[1:]:
if isinstance(v, (RawTensor, NodeMixin)):
args_name = (
orig_traced_inputs[ind]._name
if orig_traced_inputs
else idx2key[ind]
)
ind += 1
args.append(wrap(v, args_name))
else:
args.append(v)
args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch(
......@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
attr = getattr(type(self._mod), name).__get__(self, type(self))
else:
attr = getattr(self._mod, name)
if (
isinstance(attr, FunctionType)
and id(attr) in active_module_tracer().patcher.patched_fn_ids
......@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
wrapped = self.__getattr__(name)
if isinstance(wrapped, TracedModuleBuilder):
if not isinstance(mod_attr, (List, Dict)):
if not isinstance(mod_attr, (List, Dict, QATModule)):
assert mod_attr is wrapped._mod
else:
assert mod_attr is wrapped
......@@ -1977,8 +2011,6 @@ class TracedModule(Module):
def graph(self) -> InternalGraph:
"""Return the ``InternalGraph`` of this ``TracedModule``.
"""
if self._is_top:
self._update_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]
......@@ -2112,7 +2144,7 @@ class TracedModule(Module):
if hasattr(obj, "argdef_graph_map")
else None
)
if expr_graph is not None:
if expr_graph is not None and not obj.is_qat:
exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
if parent_graph is not None:
......@@ -2137,26 +2169,119 @@ class TracedModule(Module):
)
new_module.graph._re_associate_name()
new_module.graph.compile()
new_module._update_ref()
new_module.graph._reset_ids()
return new_module
def __getstate__(self):
d = self.__dict__
d = self.__dict__.copy()
for k in Module.__dict__:
d.pop(k, None)
_check_obj_attr(d)
for k in d:
if module_tracer.is_builtin(d[k]):
assert _check_builtin_module_attr(
d[k]
), "Module {} can not be serialized. ".format(type(d[k]))
d[k] = _ModuleState.get_module_state(d[k])
dump_info = {
"version": __version__,
"register_type": USER_REGISTERED_LEAF_TYPE,
"register_container_type": USER_REGISTERED_CONTAINER_TYPE,
"register_mdule": USER_REGISTERED_MODULE,
"register_function": USER_REGISTERED_FUNCTION,
}
d["dump_info"] = dump_info
return d
def __setstate__(self, state):
for k, v in state.items():
if isinstance(v, _ModuleState):
state[k] = v.to_module()
self.__dict__.update(state)
self._update_ref()
for _, graph in self.argdef_graph_map.items():
for expr in graph._exprs:
if isinstance(expr, CallFunction):
load_functional(expr)
if isinstance(expr, CallMethod):
if expr.method == "__call__":
load_call_module_expr(expr)
else:
load_call_tensor_method_expr(expr)
if isinstance(expr, Apply):
load_apply_expr(expr)
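# The loaders above may have spliced new exprs (e.g. astype CallMethods)
# into the graph; the passes below re-insert them ahead of their users and
# repair node.users / node.expr links before ids are reset.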
for _, graph in self.argdef_graph_map.items():
ind = 0
while ind < len(graph._exprs):
cur_expr = graph._exprs[ind]
has_new_expr = False
for i in cur_expr.inputs:
if i.expr not in graph._exprs and not isinstance(i.expr, Input):
graph._exprs.insert(ind, i.expr)
has_new_expr = True
if not has_new_expr:
ind += 1
for expr in graph._exprs:
for i in expr.inputs:
if expr.inputs.count(i) != i.users.count(expr):
add_or_del_count = expr.inputs.count(i) - i.users.count(expr)
if add_or_del_count > 0:
i.users.extend([expr] * add_or_del_count)
else:
[i.users.remove(expr) for _ in range(-add_or_del_count)]
for o in expr.outputs:
if o.expr is not expr:
assert o not in o.expr.outputs
o.expr = expr
for node in graph.nodes(False):
# remove users of node which doesn't use node as input
node.users = [e for e in node.users if node in e.inputs]
for expr in graph._exprs:
graph._namespace.auto_naming_for_outputs(expr)
self._update_ref()
for _, graph in self.argdef_graph_map.items():
graph._reset_ids()
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
def cpp_apply_module_trace(opdef, *args):
return Apply.apply_module_trace_hook(opdef, *args)
USER_REGISTERED_MODULE = []
USER_REGISTERED_FUNCTION = []
def register_as_builtin(mod_cls: Type[Module]) -> None:
r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
Args:
mod_cls: the module class which will be treated as builtin module in tracing.
"""
USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__))
module_tracer.register_as_builtin(mod_cls)
......@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
Args:
func: the global function to insert into the graph when it is called.
"""
USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__))
assert callable(func), "func must be a callable"
assert hasattr(func, "__code__")
fn_name = func.__code__.co_name
......@@ -2247,6 +2373,8 @@ def trace_module(
NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
)
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support "
......
......@@ -5,12 +5,17 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import inspect
from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence
from typing import Dict, Iterable, List, Optional, Sequence, Type
from .. import get_logger
from ..module import Module
logger = get_logger(__name__)
def replace_container_with_module_container(container):
has_module = False
......@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
return has_module, module_container
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func)
arg_specs_args = arg_specs.args
if is_bounded:
arg_specs_args = arg_specs.args[1:]
new_args = []
new_kwargs = {}
new_args.extend(args)
if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()):
repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
raise TypeError(
"{} got multiple values for argument {}".format(
func.__qualname__, ", ".join(repeated_arg_name)
)
)
if len(new_args) < len(arg_specs.args):
for ind in range(len(new_args), len(arg_specs_args)):
arg_name = arg_specs_args[ind]
if arg_name in kwargs:
new_args.append(kwargs[arg_name])
else:
index = ind - len(arg_specs_args) + len(arg_specs.defaults)
assert index < len(arg_specs.defaults) and index >= 0
new_args.append(arg_specs.defaults[index])
for kwarg_name in arg_specs.kwonlyargs:
if kwarg_name in kwargs:
new_kwargs[kwarg_name] = kwargs[kwarg_name]
else:
assert kwarg_name in arg_specs.kwonlydefaults
new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name]
for k, v in kwargs.items():
if k not in arg_specs.args and k not in arg_specs.kwonlyargs:
if arg_specs.varkw is None:
raise TypeError(
"{} got an unexpected keyword argument {}".format(
func.__qualname__, k
)
)
new_kwargs[k] = v
return tuple(new_args), new_kwargs
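# A hedged example of the normalization this helper performs:
#   def f(a, b=2, *, c=3): ...
#   _convert_kwargs_to_args(f, (1,), {"c": 5})  ->  ((1, 2), {"c": 5})
# positional defaults are materialized into args; keyword-only arguments
# (with their defaults) are collected into kwargs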
def _check_obj_attr(obj):
# check whether all attributes of obj are serializable
from .pytree import tree_flatten
from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef
from .expr import Expr
from .traced_module import TracedModule, InternalGraph, NameSpace
def _check_leaf_type(leaf):
leaf_type = leaf if isinstance(leaf, type) else type(leaf)
traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
return (
issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types))
or leaf_type in SUPPORTED_LEAF_TYPE
)
for _, v in obj.items():
leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
for leaf in leafs:
assert _check_leaf_type(
leaf
), "Type {} is not supported by traced module".format(
leaf if isinstance(leaf, type) else type(leaf)
)
def _check_builtin_module_attr(mod):
from .pytree import _is_leaf as _check_leaf_type
from .pytree import tree_flatten
# check whether all attributes of a builtin module are serializable
is_non_serializable_module = lambda m: isinstance(
m, Module
) and not _check_builtin_module_attr(m)
for k, v in mod.__dict__.items():
if k == "_m_dump_modulestate":
continue
if is_non_serializable_module(v):
return False
elif not isinstance(v, Module):
leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
for leaf in leafs:
if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
logger.warn(
"Type {} is not supported by traced module".format(
leaf if isinstance(leaf, type) else type(leaf)
)
)
return False
return True
class _ModuleList(Module, MutableSequence):
r"""A List-like container.
......
......@@ -15,7 +15,6 @@ import numpy as np
import megengine as mge
from megengine import Parameter, Tensor
from megengine.core.ops import builtin
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state
def test_tensor_serialization():
......@@ -88,25 +87,3 @@ def test_compatibility():
test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge")
def test_opdef_serialization():
with TemporaryFile() as f:
x = builtin.Elemwise(mode="Add")
pickle.dump(get_opdef_state(x), f)
f.seek(0)
load_x = load_opdef_from_state(pickle.load(f))
assert x == load_x
with TemporaryFile() as f:
x = builtin.Convolution(stride_h=9, compute_mode="float32")
x.strategy = (
builtin.Convolution.Strategy.PROFILE
| builtin.Convolution.Strategy.HEURISTIC
| builtin.Convolution.Strategy.REPRODUCIBLE
)
pickle.dump(get_opdef_state(x), f)
f.seek(0)
load_x = load_opdef_from_state(pickle.load(f))
assert x.strategy == load_x.strategy
assert x == load_x
......@@ -85,12 +85,12 @@ class NewModule(M.Module):
return x
def _check_expr_users(traced_module):
def _check_expr_users(flattened_module):
node_user = defaultdict(list)
for expr in traced_module.graph._exprs:
for expr in flattened_module.graph._exprs:
for node in expr.inputs:
node_user[node].append(expr)
for node in traced_module.graph.nodes():
for node in flattened_module.graph.nodes():
node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node]
......
......@@ -8,6 +8,7 @@ import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as QM
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
......@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
return getattr(self, name)
class MyConvBnRelu2d(M.ConvBnRelu2d):
pass
class MyQATConvBnRelu2d(QM.ConvBnRelu2d):
pass
class Myblcok(M.Module):
def __init__(self,):
super().__init__()
self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1)
self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1)
self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
self.add = M.Elemwise("FUSE_ADD_RELU")
......@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def build_observered_net(net: M.Module, observer_cls):
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls))
qat_net = Q.quantize_qat(
net,
qconfig=get_observer_config(observer_cls),
mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
)
Q.enable_observer(qat_net)
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
qat_net(inp)
......@@ -134,6 +147,15 @@ def test_trace_qat():
check_qparams(weight_qparams, traced_weight_qparams)
if act_qparams:
check_qparams(act_qparams, traced_act_qparams)
flatten_traced_net = traced_net.flatten()
conv0_node = flatten_traced_net.graph.get_node_by_name(
"MyModule_block0_conv0"
).as_unique()
conv0_out_node = flatten_traced_net.graph.get_node_by_name(
"MyModule_block0_conv0_out"
).as_unique()
assert isinstance(conv0_node.owner, TracedModule)
assert conv0_out_node.expr.inputs[0] is conv0_node
_check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver))
_check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver))
......
......@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from collections import defaultdict
from tempfile import TemporaryFile
import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.traced_module.serialization as S
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise
from megengine.module import Module
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallMethod, Constant
from megengine.traced_module.node import TensorNode
from megengine.traced_module.serialization import (
register_functional_loader,
register_module_loader,
register_opdef_loader,
register_tensor_method_loader,
)
from megengine.traced_module.utils import _convert_kwargs_to_args
def _check_id(traced_module):
_total_ids = traced_module.graph._total_ids
node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
assert len(set(node_ids)) == len(node_ids)
assert max(node_ids) + 1 == _total_ids[0]
expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
assert len(set(expr_ids)) == len(expr_ids)
assert max(expr_ids) + 1 == _total_ids[1]
def _check_name(flattened_module):
node_names = [n._name for n in flattened_module.graph.nodes().as_list()]
assert len(set(node_names)) == len(node_names)
def _check_expr_users(traced_module):
node_user = defaultdict(list)
for expr in traced_module.graph._exprs:
for node in expr.inputs:
node_user[node].append(expr)
if isinstance(expr, CallMethod) and expr.graph:
_check_expr_users(expr.inputs[0].owner)
for node in traced_module.graph.nodes(False):
node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node]
class MyBlock(Module):
......@@ -48,5 +93,274 @@ def test_dump_and_load():
traced_module = trace_module(module, x)
np.testing.assert_array_equal(expect, traced_module(x))
obj = pickle.dumps(traced_module)
pickle.loads(obj)
new_tm = pickle.loads(obj)
_check_id(new_tm)
_check_expr_users(new_tm)
traced_module.graph._reset_ids()
old_nodes = traced_module.graph.nodes().as_list()
new_nodes = new_tm.graph.nodes().as_list()
old_exprs = traced_module.graph.exprs().as_list()
new_exprs = new_tm.graph.exprs().as_list()
assert len(old_nodes) == len(new_nodes)
for i, j in zip(old_nodes, new_nodes):
assert i._name == j._name
assert i._qualname == j._qualname
assert i._id == j._id
assert len(old_exprs) == len(new_exprs)
for i, j in zip(old_exprs, new_exprs):
assert i._id == j._id
np.testing.assert_array_equal(expect, traced_module(x))
def test_opdef_loader():
class MyModule1(Module):
def forward(self, x, y):
op = Elemwise("ADD")
return apply(op, x, y)[0]
m = MyModule1()
x = Tensor(np.ones((20)))
y = Tensor(np.ones((20)))
traced_module = trace_module(m, x, y)
orig_loader_dict = S.OPDEF_LOADER
S.OPDEF_LOADER = {}
@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
if expr.opdef_state["mode"] == "ADD":
expr.opdef_state["mode"] = "MUL"
node = expr.inputs[1]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr,
shape=node.shape,
dtype=expr.inputs[0].dtype,
qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
expr.inputs[1] = oup
obj = pickle.dumps(traced_module)
new_module = pickle.loads(obj)
_check_id(new_module)
_check_expr_users(new_module)
_check_name(new_module.flatten())
assert (
isinstance(new_module.graph._exprs[0], CallMethod)
and new_module.graph._exprs[1].opdef.mode == "MUL"
and len(new_module.graph._exprs) == 2
)
result = new_module(x, y)
np.testing.assert_equal(result.numpy(), x.numpy())
S.OPDEF_LOADER = orig_loader_dict
def test_functional_loader():
class MyModule2(Module):
def forward(self, x, y):
return F.conv2d(x, y)
m = MyModule2()
x = Tensor(np.random.random((1, 3, 32, 32)))
y = Tensor(np.random.random((3, 3, 3, 3)))
traced_module = trace_module(m, x, y)
orig_loader_dict = S.FUNCTIONAL_LOADER
S.FUNCTIONAL_LOADER = {}
@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs = expr.kwargs
orig_weight = expr.named_args["weight"]
astype_expr = CallMethod(orig_weight, "astype")
oup = TensorNode(
astype_expr,
shape=orig_weight.shape,
dtype=orig_weight.dtype,
qparams=orig_weight.qparams,
)
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
astype_expr.return_val = (oup,)
expr.set_arg("weight", oup)
obj = pickle.dumps(traced_module)
new_module = pickle.loads(obj)
_check_expr_users(new_module)
_check_id(new_module)
result = new_module(x, y)
gt = m(x, y)
assert (
isinstance(new_module.graph._exprs[0], CallMethod)
and len(new_module.graph._exprs) == 2
)
np.testing.assert_equal(result.numpy(), gt.numpy())
S.FUNCTIONAL_LOADER = orig_loader_dict
def test_tensor_method_loader():
class MyModule3(Module):
def forward(self, x):
return x + 1
m = MyModule3()
x = Tensor(np.ones((20)))
traced_module = trace_module(m, x)
orig_loader_dict = S.TENSORMETHOD_LOADER
S.TENSORMETHOD_LOADER = {}
@register_tensor_method_loader("__add__")
def add_loader(expr):
args = list(expr.args)
if not isinstance(args[1], TensorNode):
args[1] = Tensor(args[1])
node = Constant(args[1], "const").outputs[0]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
add_expr = CallMethod(oup, "__add__")
add_expr.set_args_kwargs(oup, oup)
oup1 = TensorNode(
add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
)
add_expr.return_val = oup1
args[1] = oup1
expr.set_args_kwargs(*args)
obj = pickle.dumps(traced_module)
new_module = pickle.loads(obj)
_check_expr_users(new_module)
_check_id(new_module)
result = new_module(x)
gt = m(x)
assert (
isinstance(new_module.graph._exprs[0], Constant)
and len(new_module.graph._exprs) == 4
)
np.testing.assert_equal(result.numpy(), (x + 2).numpy())
S.TENSORMETHOD_LOADER = orig_loader_dict
def test_module_loader():
class MyModule4(Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv(x)
m = MyModule4()
x = Tensor(np.random.random((1, 3, 32, 32)))
traced_module = trace_module(m, x)
orig_loader_dict = S.MODULE_LOADER
S.MODULE_LOADER = {}
@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
module = expr.inputs[0].owner
args = list(expr.args)
orig_inp = args[1]
astype_expr = CallMethod(orig_inp, "astype")
oup = TensorNode(
astype_expr,
shape=orig_inp.shape,
dtype=orig_inp.dtype,
qparams=orig_inp.qparams,
)
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
obj = pickle.dumps(traced_module)
new_module = pickle.loads(obj)
result = new_module(x)
gt = m(x)
assert (
isinstance(new_module.graph._exprs[1], CallMethod)
and len(new_module.graph._exprs) == 3
)
np.testing.assert_equal(result.numpy(), gt.numpy())
S.MODULE_LOADER = orig_loader_dict
def test_shared_module():
class MyModule(M.Module):
def __init__(self):
super().__init__()
self.a = M.Elemwise("ADD")
self.b = self.a
def forward(self, x, y):
z = self.a(x, y)
z = self.b(z, y)
return z
x = Tensor(1)
y = Tensor(2)
m = MyModule()
tm = trace_module(m, x, y)
obj = pickle.dumps(tm)
load_tm = pickle.loads(obj)
_check_expr_users(load_tm)
_check_name(load_tm.flatten())
_check_id(load_tm)
assert load_tm.a is load_tm.b
def test_convert_kwargs_to_args():
def func(a, b, c=4, *, d, e=3, f=4):
pass
args = (1,)
kwargs = {"b": 1, "d": 6}
new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
assert new_args == (1, 1, 4)
assert new_kwargs == {"d": 6, "e": 3, "f": 4}
args = (1,)
kwargs = {"d": 6}
new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
assert new_args == (1, 4)
assert new_kwargs == {"d": 6, "e": 3, "f": 4}
def func1(a, b, c, d, e, *, f):
pass
args = ()
kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
assert new_args == (1, 2, 3, 4, 5)
assert new_kwargs == {"f": 6}
def test_opdef_serialization():
with TemporaryFile() as f:
x = builtin.Elemwise(mode="Add")
pickle.dump(x, f)
f.seek(0)
load_x = pickle.load(f)
assert x == load_x
with TemporaryFile() as f:
x = builtin.Convolution(stride_h=9, compute_mode="float32")
x.strategy = (
builtin.Convolution.Strategy.PROFILE
| builtin.Convolution.Strategy.HEURISTIC
| builtin.Convolution.Strategy.REPRODUCIBLE
)
pickle.dump(x, f)
f.seek(0)
load_x = pickle.load(f)
assert x.strategy == load_x.strategy
assert x == load_x