提交 7b19bc76 编写于 作者: M Megvii Engine Team

feat(traced_module): support traced module backward compatible serialization

GitOrigin-RevId: aaa9e51c74c11fa7955ae7bbfac476fa9bcf0d7d
上级 ffbfe59c
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from .traced_module import ( from .traced_module import (
TracedModule, TracedModule,
_register_all_builtin_module, _register_all_builtin_module,
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .. import tensor
from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
register_functional_loader,
register_module_loader,
register_opdef_loader,
register_tensor_method_loader,
)
"""
# Expr loaders examples
from ..core.ops.builtin import Elemwise
@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
if expr.opdef_state["mode"] == "ADD":
expr.opdef_state["mode"] = "MUL"
node = expr.inputs[1]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr,
shape=node.shape,
dtype=expr.inputs[0].dtype,
qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
expr.inputs[1] = oup
@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs = expr.kwargs
orig_weight = expr.named_args["weight"]
astype_expr = CallMethod(orig_weight, "astype")
oup = TensorNode(
astype_expr,
shape=orig_weight.shape,
dtype=orig_weight.dtype,
qparams=orig_weight.qparams,
)
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
astype_expr.return_val = (oup,)
expr.set_arg("weight", oup)
@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
module = expr.inputs[0].owner
args = list(expr.args)
orig_inp = args[1]
astype_expr = CallMethod(orig_inp, "astype")
oup = TensorNode(
astype_expr,
shape=orig_inp.shape,
dtype=orig_inp.dtype,
qparams=orig_inp.qparams,
)
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
@register_tensor_method_loader("__add__")
def add_loader(expr):
args = list(expr.args)
if not isinstance(args[1], TensorNode):
args[1] = tensor(args[1])
node = Constant(args[1], "const").outputs[0]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
"""
@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    r"""Patch batch-norm modules loaded from an expr that carries no version.

    An expr without a ``version`` attribute is treated as an mge 1.6 dump.
    In that case the owning module gets ``param_dim = "dim_1c11"`` unless it
    already defines ``param_dim``.
    """
    if hasattr(expr, "version"):
        return
    bn_module = expr.inputs[0].owner
    if hasattr(bn_module, "param_dim"):
        return
    bn_module.param_dim = "dim_1c11"
@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    r"""Patch the ``bn`` sub-module of fused conv-bn modules without a version.

    An expr without a ``version`` attribute is treated as an mge 1.6 dump.
    In that case ``module.bn`` gets ``param_dim = "dim_1c11"`` unless it
    already defines ``param_dim``.
    """
    if hasattr(expr, "version"):
        return
    inner_bn = expr.inputs[0].owner.bn
    if hasattr(inner_bn, "param_dim"):
        return
    inner_bn.param_dim = "dim_1c11"
@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    r"""Add one extra output node to a ``BatchNorm`` expr that has no version.

    An expr without a ``version`` attribute is treated as an mge 1.6 dump.
    A new ``TensorNode`` with shape ``(0,)`` and no dtype is inserted at index
    4 of ``expr.outputs``; it reuses the qparams of the current last output.
    """
    if hasattr(expr, "version"):
        return
    last_output = expr.outputs[-1]
    extra_output = TensorNode(
        expr, shape=(0,), dtype=None, qparams=last_output._qparams,
    )
    expr.outputs.insert(4, extra_output)
...@@ -11,19 +11,28 @@ import collections ...@@ -11,19 +11,28 @@ import collections
import copy import copy
import inspect import inspect
import re import re
from typing import Callable, Dict, List, Optional, Union import weakref
from importlib import import_module
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from ..core._imperative_rt import OpDef from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing from ..core._imperative_rt.core2 import (
apply,
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const from ..core.ops.special import Const
from ..module import Module from ..module import Module
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from ..version import __version__
from .module_tracer import active_module_tracer, module_tracer from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import get_opdef_state, load_opdef_from_state from .serialization import _ModuleState
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
def rstrip(s: str, __chars: str): def rstrip(s: str, __chars: str):
...@@ -112,6 +121,7 @@ class Expr: ...@@ -112,6 +121,7 @@ class Expr:
node.users.append(self) node.users.append(self)
else: else:
assert node is None assert node is None
assert not isinstance(val, (Module, RawTensor))
assert _is_leaf(val) and _is_const_leaf(val) assert _is_leaf(val) and _is_const_leaf(val)
idx = len(self.inputs) + len(self.const_val) idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val)) self.const_val.append((idx, val))
...@@ -132,14 +142,14 @@ class Expr: ...@@ -132,14 +142,14 @@ class Expr:
current_graph._namespace.auto_naming_for_outputs(self) current_graph._namespace.auto_naming_for_outputs(self)
def unflatten_args(self, inputs): def unflatten_args(self, inputs):
if self.arg_def is not None: assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
inputs = list(inputs) type(self).__name__
for idx, val in self.const_val: )
inputs.insert(idx, val) inputs = list(inputs)
args, kwargs = self.arg_def.unflatten(inputs) for idx, val in self.const_val:
return args, kwargs inputs.insert(idx, val)
else: args, kwargs = self.arg_def.unflatten(inputs)
return inputs, {} return args, kwargs
def replace_inputs(self, repl_dict: Dict[Node, Node]): def replace_inputs(self, repl_dict: Dict[Node, Node]):
r"""Replace the input Nodes of this Expr. r"""Replace the input Nodes of this Expr.
...@@ -165,6 +175,39 @@ class Expr: ...@@ -165,6 +175,39 @@ class Expr:
node.users.remove(self) node.users.remove(self)
repl_node.users.append(self) repl_node.users.append(self)
@property
def _support_set_args_kwargs(self):
return False
def set_args_kwargs(self, *args, **kwargs):
    r"""Replace the args and kwargs of this Expr.

    The arguments are normalized with ``_convert_kwargs_to_args`` against the
    function returned by ``self._get_func()`` and then flattened. Node leaves
    become ``self.inputs``; every other leaf is recorded in ``self.const_val``
    together with its position in the flattened sequence. The ``users`` lists
    of nodes that stop or start being inputs are updated to match.
    """
    assert (
        self._support_set_args_kwargs
    ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
    args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
    flat_values, new_arg_def = tree_flatten((args, kwargs))
    previous_inputs = self.inputs
    self.inputs = []
    self.const_val = []
    for item in flat_values:
        if not isinstance(item, (TensorNode, ModuleNode)):
            assert _is_leaf(item) and _is_const_leaf(item)
            # position of the constant inside the flattened argument list
            position = len(self.inputs) + len(self.const_val)
            self.const_val.append((position, item))
            continue
        self.inputs.append(item)

    # nodes that are no longer consumed by this expr
    for old_node in previous_inputs:
        if old_node not in self.inputs:
            old_node.users.remove(self)

    # nodes that are newly consumed by this expr
    for new_node in self.inputs:
        if new_node not in previous_inputs:
            new_node.users.append(self)

    self.arg_def = new_arg_def
@property @property
def kwargs(self): def kwargs(self):
r"""Get the keyword arguments of the operation corresponding to this Expr.""" r"""Get the keyword arguments of the operation corresponding to this Expr."""
...@@ -177,6 +220,61 @@ class Expr: ...@@ -177,6 +220,61 @@ class Expr:
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args
def _get_func(self):
# get called function when the expr is interpreted
raise NotImplementedError
@property
def named_args(self):
    r"""Map parameter names of the called function to the current arguments.

    The mapping is produced by ``inspect.getcallargs`` on the function
    returned by ``self._get_func()``.
    """
    target = self._get_func()
    positional = self.args
    keywords = self.kwargs
    return inspect.getcallargs(target, *positional, **keywords)
def set_arg(self, name, val):
    r"""Set the argument called ``name`` of this Expr to ``val``.

    The name is resolved against the signature of ``self._get_func()``:

    * a name already present in ``self.kwargs`` is overwritten there;
    * a positional parameter is replaced in place, or passed by keyword when
      it is currently left at its default (no positional slot exists yet);
    * the ``*args`` parameter replaces every extra positional argument, a
      non-sequence ``val`` being treated as a single extra argument;
    * any other name must be a keyword-only parameter or be accepted by
      ``**kwargs``.

    Existing keyword arguments are preserved in every branch.
    """
    func = self._get_func()
    current_args = list(self.args)
    # copy so that the mapping returned by ``self.kwargs`` is never mutated
    current_kwargs = dict(self.kwargs)

    if name in current_kwargs:
        current_kwargs[name] = val
        self.set_args_kwargs(*current_args, **current_kwargs)
        return

    arg_spec = inspect.getfullargspec(func)
    if name in arg_spec.args:
        ind = arg_spec.args.index(name)
        if ind < len(current_args):
            current_args[ind] = val
        else:
            # the parameter currently takes its default value
            current_kwargs[name] = val
        self.set_args_kwargs(*current_args, **current_kwargs)
    elif name == arg_spec.varargs:
        assert arg_spec.varargs is not None
        assert len(current_args) >= len(arg_spec.args)
        val = (val,) if not isinstance(val, Sequence) else val
        self.set_args_kwargs(
            *current_args[0 : len(arg_spec.args)], *val, **current_kwargs
        )
    else:
        # ``getfullargspec().args`` does not list keyword-only parameters,
        # so they have to be accepted here explicitly.
        assert (
            name in arg_spec.kwonlyargs or arg_spec.varkw is not None
        ), "func {} does't have argument named {}".format(func, name)
        current_kwargs[name] = val
        self.set_args_kwargs(*current_args, **current_kwargs)
@property
def return_val(self):
    r"""Outputs of this Expr, rebuilt into their container structure by ``out_def``."""
    return self.out_def.unflatten(self.outputs)

@return_val.setter
def return_val(self, new_outputs):
    # Flatten the new return value, treating every Node as a leaf.
    flat_nodes, new_out_def = tree_flatten(
        new_outputs, is_leaf=lambda leaf: isinstance(leaf, Node)
    )
    only_nodes = all(isinstance(node, Node) for node in flat_nodes)
    assert (
        only_nodes
    ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
    # A node may be adopted only if it has no producer yet or is produced by self.
    owned_or_free = all(node.expr in (None, self) for node in flat_nodes)
    assert (
        owned_or_free
    ), "Some nodes are produced by other expr, can not be output of expr {}".format(
        self
    )
    self.outputs = flat_nodes
    self.out_def = new_out_def
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this Expr.""" r"""Get the parent graph of this Expr."""
...@@ -184,12 +282,6 @@ class Expr: ...@@ -184,12 +282,6 @@ class Expr:
return self._top_graph() return self._top_graph()
return None return None
def __getstate__(self):
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
return state
@classmethod @classmethod
def _get_next_id(cls): def _get_next_id(cls):
return cls.__total_id return cls.__total_id
...@@ -199,6 +291,23 @@ class Expr: ...@@ -199,6 +291,23 @@ class Expr:
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
def __copy__(self):
    r"""Return a shallow copy: a new instance sharing every attribute value.

    The instance is created with ``__new__`` so ``__init__`` is not run.
    """
    duplicate = self.__class__.__new__(self.__class__)
    for attr_name, attr_value in self.__dict__.items():
        duplicate.__dict__[attr_name] = attr_value
    return duplicate
def __deepcopy__(self, memo):
    r"""Return a deep copy that omits attributes holding weak references.

    The new instance is created with ``__new__`` and registered in ``memo``
    before the attributes are copied, so reference cycles back to ``self``
    resolve to the copy.
    """
    duplicate = self.__class__.__new__(self.__class__)
    memo[id(self)] = duplicate
    copied_state = {
        attr_name: copy.deepcopy(attr_value, memo)
        for attr_name, attr_value in self.__dict__.items()
        if not isinstance(attr_value, weakref.ReferenceType)
    }
    duplicate.__dict__.update(copied_state)
    return duplicate
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
...@@ -229,6 +338,17 @@ class Input(Expr): ...@@ -229,6 +338,17 @@ class Input(Expr):
def __repr__(self): def __repr__(self):
return "%{}:\t{} = Input()".format(self._id, self.outputs[0]) return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
def __getstate__(self):
    r"""Return the pickled state: an explicit subset of attributes.

    The state is passed through ``_check_obj_attr`` before being returned.
    """
    state = dict(
        _id=self._id,
        _disable_remove=self._disable_remove,
        inputs=self.inputs,
        outputs=self.outputs,
        name=self.name,
    )
    _check_obj_attr(state)
    return state
# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr): class GetAttr(Expr):
...@@ -276,11 +396,23 @@ class GetAttr(Expr): ...@@ -276,11 +396,23 @@ class GetAttr(Expr):
def __repr__(self): def __repr__(self):
out_type = "Tensor" out_type = "Tensor"
if isinstance(self.outputs[0], ModuleNode): if isinstance(self.outputs[0], ModuleNode):
out_type = self.outputs[0].module_type.__name__ m_type = self.outputs[0].module_type
out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format( return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
self._id, self.outputs[0], self.inputs[0], self.name, out_type self._id, self.outputs[0], self.inputs[0], self.name, out_type
) )
def __getstate__(self):
    r"""Return the pickled state: an explicit subset of attributes.

    The state is passed through ``_check_obj_attr`` before being returned.
    """
    state = dict(
        _id=self._id,
        _disable_remove=self._disable_remove,
        inputs=self.inputs,
        outputs=self.outputs,
        name=self.name,
    )
    _check_obj_attr(state)
    return state
# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr): class CallMethod(Expr):
...@@ -307,6 +439,7 @@ class CallMethod(Expr): ...@@ -307,6 +439,7 @@ class CallMethod(Expr):
node, node,
] ]
self.const_val = [] self.const_val = []
self.arg_def = tree_flatten(((node,), {}))[1]
self.method = method self.method = method
@classmethod @classmethod
...@@ -342,6 +475,27 @@ class CallMethod(Expr): ...@@ -342,6 +475,27 @@ class CallMethod(Expr):
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs return outputs
def _get_func(self):
    r"""Return the unbound method this Expr calls.

    The receiver is the first argument. It may be a type, a ``ModuleNode``
    (its ``module_type`` is used) or a ``TensorNode`` (``Tensor`` is used).
    For ``Module`` subclasses the method is ``forward``; otherwise it is the
    attribute named by ``self.method``.
    """
    receiver = self.args[0]
    if isinstance(receiver, type):
        obj_type = receiver
    elif isinstance(receiver, ModuleNode):
        obj_type = receiver.module_type
    else:
        assert isinstance(receiver, TensorNode)
        obj_type = Tensor
    method_name = "forward" if issubclass(obj_type, Module) else self.method
    return getattr(obj_type, method_name)
@property
def _support_set_args_kwargs(self):
    r"""Whether the args/kwargs of this Expr may be modified.

    Only exprs that call a tensor method or a builtin module support it.
    """
    receiver = self.args[0]
    if isinstance(receiver, (TensorNode, type)):
        return True
    return receiver.module_type is not Module
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:]) args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
...@@ -359,6 +513,21 @@ class CallMethod(Expr): ...@@ -359,6 +513,21 @@ class CallMethod(Expr):
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
def __getstate__(self):
    r"""Return the pickled state, tagged with the current ``__version__``.

    The state is passed through ``_check_obj_attr`` before being returned.
    """
    state = dict(
        _id=self._id,
        _disable_remove=self._disable_remove,
        inputs=self.inputs,
        const_val=self.const_val,
        method=self.method,
        arg_def=self.arg_def,
        out_def=self.out_def,
        outputs=self.outputs,
        version=__version__,
    )
    _check_obj_attr(state)
    return state
# expr: outputs = apply(self.opdef, *inputs) # expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr): class Apply(Expr):
...@@ -394,14 +563,32 @@ class Apply(Expr): ...@@ -394,14 +563,32 @@ class Apply(Expr):
) )
def __getstate__(self): def __getstate__(self):
state = super().__getstate__() opdef_state = self.opdef.__getstate__()
state["opdef"] = get_opdef_state(state["opdef"]) opdef_state["opdef_type"] = type(self.opdef)
state = {
"_id": self._id,
"_disable_remove": self._disable_remove,
"opdef_state": opdef_state,
"inputs": self.inputs,
"outputs": self.outputs,
"version": __version__,
}
_check_obj_attr(state)
return state return state
def __setstate__(self, state): def __setstate__(self, state):
state["opdef"] = load_opdef_from_state(state["opdef"]) # compat with mge 1.6
for k, v in state.items(): if "opdef" in state and "opdef_state" not in state:
setattr(self, k, v) opdef_state = state.pop("opdef")
opdef_state["opdef_type"] = opdef_state.pop("type")
state["opdef_state"] = opdef_state
self.__dict__.update(state)
assert isinstance(state["opdef_state"], dict)
opdef_state = state["opdef_state"].copy()
opdef_type = opdef_state.pop("opdef_type")
opdef_obj = opdef_type()
opdef_obj.__setstate__(opdef_state)
setattr(self, "opdef", opdef_obj)
@classmethod @classmethod
def apply_module_trace_hook(cls, opdef, *inputs): def apply_module_trace_hook(cls, opdef, *inputs):
...@@ -458,12 +645,24 @@ class CallFunction(Expr): ...@@ -458,12 +645,24 @@ class CallFunction(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs) args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs) func = (
self.func
if not is_tracing_module()
else active_module_tracer().patcher.wrap_fn(self.func)
)
outputs = func(*args, **kwargs)
if outputs is None: if outputs is None:
return outputs return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs return outputs
def _get_func(self):
return self.func
@property
def _support_set_args_kwargs(self):
return True
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args) args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
...@@ -477,6 +676,33 @@ class CallFunction(Expr): ...@@ -477,6 +676,33 @@ class CallFunction(Expr):
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
def __getstate__(self):
    r"""Return the pickled state, tagged with the current ``__version__``.

    The called function is stored as a ``(module name, qualified name)`` pair
    instead of the function object. The state is passed through
    ``_check_obj_attr`` before being returned.
    """
    func_location = (self.func.__module__, self.func.__qualname__)
    state = dict(
        _id=self._id,
        _disable_remove=self._disable_remove,
        func=func_location,
        const_val=self.const_val,
        inputs=self.inputs,
        arg_def=self.arg_def,
        out_def=self.out_def,
        outputs=self.outputs,
        version=__version__,
    )
    _check_obj_attr(state)
    return state
def __setstate__(self, state):
    r"""Restore the state and try to resolve ``func`` back to a callable.

    When ``func`` is a ``(module name, qualified name)`` pair, the module is
    imported and the dotted name is walked attribute by attribute. Any failure
    is swallowed, leaving ``func`` exactly as it was stored.
    """
    self.__dict__.update(state)
    try:
        stored = self.func
        if not isinstance(stored, tuple):
            return
        module_name, qualified_name = stored
        resolved = import_module(module_name)
        for part in qualified_name.split("."):
            resolved = getattr(resolved, part)
        self.func = resolved
    except Exception:
        # best effort: keep whatever was stored when resolution fails
        pass
# expr outputs = self.value # expr outputs = self.value
class Constant(Expr): class Constant(Expr):
...@@ -496,6 +722,13 @@ class Constant(Expr): ...@@ -496,6 +722,13 @@ class Constant(Expr):
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) or c.is_qat assert module_tracer.is_builtin(c) or c.is_qat
if isinstance(c, RawTensor):
if is_tracing_module():
unset_module_tracing()
c = Tensor(c)
set_module_tracing()
else:
c = Tensor(c)
self.value = c self.value = c
self.name = name self.name = name
self.inputs = [] self.inputs = []
...@@ -530,9 +763,25 @@ class Constant(Expr): ...@@ -530,9 +763,25 @@ class Constant(Expr):
) )
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = {
if "_top_graph" in state: "_id": self._id,
state.pop("_top_graph") "_disable_remove": self._disable_remove,
"value": self.value,
"name": self.name,
"inputs": self.inputs,
"outputs": self.outputs,
}
_check_obj_attr(state)
if isinstance(self.value, RawTensor): if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value) state["value"] = Tensor(self.value)
if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
_check_builtin_module_attr(self.value)
state["value"] = _ModuleState.get_module_state(self.value)
return state return state
def __setstate__(self, state):
    r"""Restore the state, converting ``_ModuleState`` values back to modules.

    The conversion is written into ``state`` itself before the instance
    ``__dict__`` is updated from it.
    """
    for key in list(state):
        stored = state[key]
        if isinstance(stored, _ModuleState):
            state[key] = stored.to_module()
    self.__dict__.update(state)
...@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [ ...@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
"astype", "astype",
"reshape", "reshape",
"_broadcast", "_broadcast",
"transpose",
"flatten", "flatten",
"sum", "sum",
"prod", "prod",
......
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc import abc
import copy
import weakref import weakref
from importlib import import_module
from typing import Any, Dict, List, Tuple, Type from typing import Any, Dict, List, Tuple, Type
import numpy import numpy
...@@ -14,7 +16,9 @@ import numpy ...@@ -14,7 +16,9 @@ import numpy
from .. import get_logger from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor from ..tensor import Tensor
from .utils import _check_obj_attr
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -145,6 +149,23 @@ class Node: ...@@ -145,6 +149,23 @@ class Node:
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
def __copy__(self):
    r"""Return a shallow copy: a new instance sharing every attribute value.

    The instance is created with ``__new__`` so ``__init__`` is not run.
    """
    duplicate = self.__class__.__new__(self.__class__)
    for attr_name, attr_value in self.__dict__.items():
        duplicate.__dict__[attr_name] = attr_value
    return duplicate
def __deepcopy__(self, memo):
    r"""Return a deep copy without weak references and without ``actual_node``.

    The new instance is created with ``__new__`` and registered in ``memo``
    before the attributes are copied, so reference cycles back to ``self``
    resolve to the copy.
    """
    duplicate = self.__class__.__new__(self.__class__)
    memo[id(self)] = duplicate
    copied_state = {
        attr_name: copy.deepcopy(attr_value, memo)
        for attr_name, attr_value in self.__dict__.items()
        if not isinstance(attr_value, weakref.ReferenceType)
        and attr_name != "actual_node"
    }
    duplicate.__dict__.update(copied_state)
    return duplicate
class ModuleNode(Node): class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects.""" r"""``ModuleNode`` represents the Module objects."""
...@@ -157,19 +178,28 @@ class ModuleNode(Node): ...@@ -157,19 +178,28 @@ class ModuleNode(Node):
super().__init__(expr, name, qualname) super().__init__(expr, name, qualname)
def __getstate__(self): def __getstate__(self):
return { state = {
"expr": self.expr, "expr": self.expr,
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
"_name": self._name, "_name": self._name,
"_qualname": self._qualname, "_qualname": self._qualname,
"module_type": self.module_type, "module_type": (self.module_type.__module__, self.module_type.__qualname__),
} }
_check_obj_attr(state)
return state
def __setstate__(self, state): def __setstate__(self, state):
if "_orig_name" in state: if "_orig_name" in state:
state["_qualname"] = state.pop("_orig_name") state["_qualname"] = state.pop("_orig_name")
self.__dict__.update(state) self.__dict__.update(state)
try:
if isinstance(self.module_type, tuple):
mname, classname = self.module_type
mtype = getattr(import_module(mname), classname)
self.module_type = mtype
except Exception:
pass
@property @property
def owner(self): def owner(self):
...@@ -185,12 +215,26 @@ class TensorNode(Node): ...@@ -185,12 +215,26 @@ class TensorNode(Node):
_shape = None # type: Tuple[int] _shape = None # type: Tuple[int]
_dtype = None # type: numpy.dtype _dtype = None # type: numpy.dtype
_qparams = None _qparams = None # type: QParams
_device = None _device = None
_value = None # type: Tensor _value = None # type: Tensor
def __init__(
    self,
    expr: "Expr",
    name: str = None,
    qualname: str = None,
    shape: Tuple[int] = None,
    dtype: numpy.dtype = None,
    qparams: QParams = None,
):
    r"""Create a tensor node, optionally recording its shape, dtype and qparams.

    Args:
        expr: forwarded to the base-class initializer.
        name: forwarded to the base-class initializer.
        qualname: forwarded to the base-class initializer.
        shape: stored as ``_shape``.
        dtype: stored as ``_dtype``.
        qparams: stored as ``_qparams``.
    """
    super().__init__(expr, name, qualname)
    self._shape = shape
    # Store the dtype argument itself; the previous code assigned ``shape``
    # here, so ``_dtype`` silently held the shape tuple.
    self._dtype = dtype
    self._qparams = qparams
def __getstate__(self): def __getstate__(self):
return { state = {
"expr": self.expr, "expr": self.expr,
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
...@@ -201,6 +245,8 @@ class TensorNode(Node): ...@@ -201,6 +245,8 @@ class TensorNode(Node):
"_name": self._name, "_name": self._name,
"_qualname": self._qualname, "_qualname": self._qualname,
} }
_check_obj_attr(state)
return state
def __setstate__(self, state): def __setstate__(self, state):
if "_orig_name" in state: if "_orig_name" in state:
...@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC): ...@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
assert isinstance(node, TensorNode) assert isinstance(node, TensorNode)
assert isinstance(value, RawTensor) assert isinstance(value, RawTensor)
if isinstance(value, RawTensor): if isinstance(value, RawTensor):
node._dtype = value.dtype try:
node._dtype = value.dtype
except RuntimeError:
node._dtype = None
node._shape = ( node._shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape value._tuple_shape if isinstance(value, Tensor) else value.shape
) )
......
...@@ -7,15 +7,18 @@ ...@@ -7,15 +7,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
from collections import OrderedDict from collections import OrderedDict, defaultdict
from functools import partial
from typing import Callable, NamedTuple from typing import Callable, NamedTuple
import numpy as np import numpy as np
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.common import CompNode from ..core._imperative_rt.common import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._wrap import Device from ..core._wrap import Device
from ..core.tensor.dtype import QuantDtypeMeta from ..core.tensor.dtype import QuantDtypeMeta
from ..distributed import Group
from ..module import Module from ..module import Module
from ..quantization.utils import LSQParams, QParams, QuantMode from ..quantization.utils import LSQParams, QParams, QuantMode
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
...@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = { ...@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
type(Ellipsis), type(Ellipsis),
QuantMode, QuantMode,
ArgsIndex, ArgsIndex,
Group,
} }
USER_REGISTERED_LEAF_TYPE = []
USER_REGISTERED_CONTAINER_TYPE = []
# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number] SUPPORTED_LEAF_CLS = [
Module,
Node,
NodeMixin,
np.dtype,
np.ndarray,
np.number,
np.bool_,
OpDef,
]
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
def register_supported_type(type, flatten=None, unflatten=None): def register_supported_type(type, flatten=None, unflatten=None):
tp_info = (type.__module__, type.__qualname__)
if flatten and unflatten: if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
else: else:
SUPPORTED_LEAF_CLS.append(type) USER_REGISTERED_LEAF_TYPE.append(tp_info)
_register_supported_type(type, flatten, unflatten)
def _dict_flatten(inp):
aux_data = []
results = []
for key, value in sorted(inp.items()):
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)
def _dict_unflatten(inps, aux_data): def _register_supported_type(type, flatten=None, unflatten=None):
return dict(zip(aux_data, inps)) if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
else:
SUPPORTED_LEAF_CLS.append(type)
def _ordereddict_flatten(inp): def _dict_flatten(ordered, inp):
aux_data = [] aux_data = []
results = [] results = []
for key, value in inp.items(): dict_items = inp.items() if ordered else sorted(inp.items())
for key, value in dict_items:
results.append(value) results.append(value)
aux_data.append(key) aux_data.append(key)
return results, tuple(aux_data) return results, tuple(aux_data)
def _ordereddict_unflatten(inps, aux_data): def _dict_unflatten(dict_type, inps, aux_data):
return OrderedDict(zip(aux_data, inps)) return dict_type(zip(aux_data, inps))
def qparams_flatten(inp): def qparams_flatten(inp):
...@@ -99,33 +111,41 @@ def qparams_flatten(inp): ...@@ -99,33 +111,41 @@ def qparams_flatten(inp):
return results, tuple(aux_data) return results, tuple(aux_data)
def qparams_unflatten(inp, aux_data): def qparams_unflatten(qparam_type, inp, aux_data):
obj = QParams.__new__(QParams) obj = qparam_type.__new__(qparam_type)
for k, v in zip(aux_data, inp): for k, v in zip(aux_data, inp):
setattr(obj, k, v) setattr(obj, k, v)
return obj return obj
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) _register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) _register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten) _register_supported_type(
register_supported_type( dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten )
_register_supported_type(
defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
) )
register_supported_type( _register_supported_type(
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
)
_register_supported_type(
slice, slice,
lambda x: ([x.start, x.stop, x.step], None), lambda x: ([x.start, x.stop, x.step], None),
lambda x, aux_data: slice(x[0], x[1], x[2]), lambda x, aux_data: slice(x[0], x[1], x[2]),
) )
register_supported_type(QParams, qparams_flatten, qparams_unflatten) _register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
_register_supported_type(
LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
)
def _is_leaf(obj): def _is_leaf(obj):
if isinstance(obj, type): obj_type = obj if isinstance(obj, type) else type(obj)
return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
return ( return (
isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS))
or obj_type in SUPPORTED_LEAF_TYPE
) )
......
...@@ -5,30 +5,158 @@ ...@@ -5,30 +5,158 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict from importlib import import_module
from typing import Dict, Tuple
from ..core._imperative_rt import OpDef from ..core._imperative_rt import OpDef
from ..core.ops import builtin from ..core.ops import builtin
from ..tensor import Tensor
from ..version import __version__ from ..version import __version__
from .utils import _convert_kwargs_to_args
OPDEF_PARAM_LOADER = {} OPDEF_LOADER = {}
FUNCTIONAL_LOADER = {}
TENSORMETHOD_LOADER = {}
MODULE_LOADER = {}
def get_opdef_state(obj: OpDef) -> Dict: class _ModuleState:
state = obj.__getstate__() obj = None
state["type"] = type(obj)
state["version"] = __version__
return state
def __init__(self, module: Tuple, state: Dict, version: str):
self.module = module
self.state = state
self.version = version
def load_opdef_from_state(state: Dict) -> OpDef: @classmethod
assert "type" in state and issubclass(state["type"], OpDef) def get_module_state(cls, module):
assert "version" in state typem = (type(module).__module__, type(module).__qualname__)
opdef_type = state.pop("type") state = module.__dict__.copy()
if opdef_type in OPDEF_PARAM_LOADER: state.pop("_m_dump_modulestate", None)
loader = OPDEF_PARAM_LOADER[opdef_type] if hasattr(module, "_m_dump_modulestate"):
state = loader(state) assert isinstance(module._m_dump_modulestate, cls)
state.pop("version") module._m_dump_modulestate.__init__(typem, state, __version__)
opdef_obj = opdef_type() else:
opdef_obj.__setstate__(state) module.__dict__["_m_dump_modulestate"] = _ModuleState(
return opdef_obj typem, state, __version__
)
return module._m_dump_modulestate
def __getstate__(self):
return {"module": self.module, "state": self.state, "version": self.version}
def to_module(self):
if self.obj is None:
typem = getattr(import_module(self.module[0]), self.module[1])
m_obj = typem.__new__(typem)
m_obj.__dict__.update(self.state)
self.obj = m_obj
return self.obj
def register_opdef_loader(*opdefs):
    """Build a decorator that registers a loader for the given opdef types.

    The decorated function is stored in ``OPDEF_LOADER`` under every key in
    ``opdefs`` (``load_apply_expr`` looks loaders up by ``type(expr.opdef)``)
    and is returned unchanged, so it stays usable as a normal function.
    Registering a key twice trips an assertion.
    """

    def _decorator(fn):
        for key in opdefs:
            # Each key may own exactly one loader.
            assert key not in OPDEF_LOADER
            OPDEF_LOADER[key] = fn
        return fn

    return _decorator
def register_functional_loader(*funcs):
    """Build a decorator that registers a loader for the given functions.

    Each entry of ``funcs`` is used as a key of ``FUNCTIONAL_LOADER``;
    ``load_functional`` looks keys up as ``(module name, qualified name)``
    tuples. The decorated function is returned unchanged. Registering a key
    twice trips an assertion.
    """

    def _decorator(fn):
        for key in funcs:
            # Each key may own exactly one loader.
            assert key not in FUNCTIONAL_LOADER
            FUNCTIONAL_LOADER[key] = fn
        return fn

    return _decorator
def register_module_loader(*module_types):
    """Build a decorator that registers a loader for the given module types.

    Each entry of ``module_types`` is used as a key of ``MODULE_LOADER``;
    ``load_call_module_expr`` looks keys up as ``(module name, qualified
    name)`` tuples. The decorated function is returned unchanged.
    Registering a key twice trips an assertion.
    """

    def _decorator(fn):
        for key in module_types:
            # Each key may own exactly one loader.
            assert key not in MODULE_LOADER
            MODULE_LOADER[key] = fn
        return fn

    return _decorator
def register_tensor_method_loader(*methods):
    """Build a decorator that registers a loader for the given tensor methods.

    Each entry of ``methods`` is used as a key of ``TENSORMETHOD_LOADER``;
    ``load_call_tensor_method_expr`` looks loaders up by ``expr.method``.
    The decorated function is returned unchanged. Registering a key twice
    trips an assertion.
    """

    def _decorator(fn):
        for key in methods:
            # Each key may own exactly one loader.
            assert key not in TENSORMETHOD_LOADER
            TENSORMETHOD_LOADER[key] = fn
        return fn

    return _decorator
def _replace_args_kwargs(expr, new_args, new_kwargs):
if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
expr.kwargs.keys()
):
expr.set_args_kwargs(*new_args, **new_kwargs)
def load_functional(expr):
    """Restore the function referenced by a loaded function-call expr.

    ``expr.func`` may be either a callable or a ``(module name, qualified
    name)`` tuple. A loader registered in ``FUNCTIONAL_LOADER`` for that
    tuple is run first; the function is then re-imported by name and stored
    back on ``expr.func``. If the expr has no ``version`` attribute or was
    dumped by a different version, its args/kwargs are re-normalized against
    the resolved function's signature.
    """
    target = expr.func
    if callable(target):
        func = (target.__module__, target.__qualname__)
    else:
        func = target
    assert isinstance(func, tuple)

    if func in FUNCTIONAL_LOADER:
        FUNCTIONAL_LOADER[func](expr)

    # Resolve the dotted qualified name one attribute at a time.
    mname, fname = func
    resolved = import_module(mname)
    for attr in fname.split("."):
        resolved = getattr(resolved, attr)
    expr.func = resolved
    assert callable(expr.func)

    if hasattr(expr, "version") and expr.version == __version__:
        return
    args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_call_module_expr(expr):
    """Restore the module type referenced by a loaded module-call expr.

    The module type is read from ``expr.inputs[0].module_type`` and may be a
    class or a ``(module name, qualified name)`` tuple. A loader registered
    in ``MODULE_LOADER`` for that tuple is run first; a tuple-valued module
    type is then replaced by the imported class. If the expr has no
    ``version`` attribute or was dumped by a different version, its
    args/kwargs are re-normalized against the class's ``forward``.
    """
    key = expr.inputs[0].module_type
    if isinstance(key, type):
        key = (key.__module__, key.__qualname__)
    if key in MODULE_LOADER:
        MODULE_LOADER[key](expr)

    # Re-read the node: the loader above is free to rewrite expr.inputs.
    module_node = expr.inputs[0]
    if isinstance(module_node.module_type, tuple):
        mname, classname = module_node.module_type
        module_node.module_type = getattr(import_module(mname), classname)

    if hasattr(expr, "version") and expr.version == __version__:
        return
    fwd_func = getattr(module_node.module_type, "forward")
    args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_call_tensor_method_expr(expr):
    """Post-process a loaded tensor-method call expr.

    A loader registered in ``TENSORMETHOD_LOADER`` for ``expr.method`` is run
    first. If the expr has no ``version`` attribute or was dumped by a
    different version, its args/kwargs are re-normalized against the method
    looked up on ``expr.args[0]`` when that is a class, or on ``Tensor``
    otherwise.
    """
    if expr.method in TENSORMETHOD_LOADER:
        TENSORMETHOD_LOADER[expr.method](expr)

    if hasattr(expr, "version") and expr.version == __version__:
        return

    receiver = expr.args[0]
    if isinstance(receiver, type):
        tmethod = getattr(receiver, expr.method)
    else:
        tmethod = getattr(Tensor, expr.method)
    args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_apply_expr(expr):
    """Rebuild the opdef object of a loaded ``Apply`` expr from its state.

    A loader registered in ``OPDEF_LOADER`` for ``type(expr.opdef)`` is run
    first. The opdef class is then popped from ``expr.opdef_state`` under the
    ``"opdef_type"`` key, instantiated without arguments, fed the remaining
    state through ``__setstate__`` and stored on ``expr.opdef``.
    """
    loader_key = type(expr.opdef)
    if loader_key in OPDEF_LOADER:
        OPDEF_LOADER[loader_key](expr)

    state = expr.opdef_state
    # NOTE: pop() mutates expr.opdef_state in place, exactly as before.
    opdef_cls = state.pop("opdef_type")
    rebuilt = opdef_cls()
    rebuilt.__setstate__(state)
    expr.opdef = rebuilt
...@@ -14,6 +14,7 @@ import inspect ...@@ -14,6 +14,7 @@ import inspect
import keyword import keyword
import re import re
import weakref import weakref
from importlib import import_module
from inspect import getcallargs, getmembers, isclass, ismethod from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain from itertools import chain
from types import FunctionType from types import FunctionType
...@@ -53,6 +54,7 @@ from ..quantization.observer import ( ...@@ -53,6 +54,7 @@ from ..quantization.observer import (
SyncMinMaxObserver, SyncMinMaxObserver,
) )
from ..tensor import Tensor from ..tensor import Tensor
from ..version import __version__
from .expr import ( from .expr import (
Apply, Apply,
CallFunction, CallFunction,
...@@ -80,8 +82,27 @@ from .module_tracer import ( ...@@ -80,8 +82,27 @@ from .module_tracer import (
set_active_module_tracer, set_active_module_tracer,
) )
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, tree_flatten from .pytree import (
from .utils import replace_container_with_module_container USER_REGISTERED_CONTAINER_TYPE,
USER_REGISTERED_LEAF_TYPE,
ArgsIndex,
TreeDef,
_register_supported_type,
tree_flatten,
)
from .serialization import (
_ModuleState,
load_apply_expr,
load_call_module_expr,
load_call_tensor_method_expr,
load_functional,
)
from .utils import (
_check_builtin_module_attr,
_check_obj_attr,
_convert_kwargs_to_args,
replace_container_with_module_container,
)
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -341,7 +362,7 @@ class NameSpace: ...@@ -341,7 +362,7 @@ class NameSpace:
def create_unique_name(self, name: str, node: Any = None) -> str: def create_unique_name(self, name: str, node: Any = None) -> str:
assert isinstance(name, str), "The name must be a string" assert isinstance(name, str), "The name must be a string"
if name in self._used_names and self._used_names[name] is node: if name in self._used_names and (self._used_names[name] is node):
return name return name
name = re.sub("[^0-9a-zA-Z_]+", "_", name) name = re.sub("[^0-9a-zA-Z_]+", "_", name)
...@@ -1067,6 +1088,7 @@ class InternalGraph: ...@@ -1067,6 +1088,7 @@ class InternalGraph:
if node2value[n][1] == 0: if node2value[n][1] == 0:
node2value.pop(n) node2value.pop(n)
if values is not None: if values is not None:
assert len(values) == len(expr.outputs)
for n, v in zip(expr.outputs, values): for n, v in zip(expr.outputs, values):
if ref_count(n) > 0: if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)] node2value[n] = [v, ref_count(n)]
...@@ -1105,13 +1127,27 @@ class InternalGraph: ...@@ -1105,13 +1127,27 @@ class InternalGraph:
return res return res
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = {
if "_top_graph" in state: "_exprs": self._exprs,
state.pop("_top_graph") "_inputs": self._inputs,
"_outputs": self._outputs,
"_watch_point": [],
"_end_point": [],
"_namespace": self._namespace,
"_rst": collections.defaultdict(list),
"_name": self._name,
"_qualname": self._qualname,
}
if self._total_ids:
state["_total_ids"] = self._total_ids
_check_obj_attr(state)
return state return state
def __setstate__(self, state): def __setstate__(self, state):
old_version = False old_version = False
if "_module_name" in state: if "_module_name" in state:
old_version = True old_version = True
state["_qualname"] = state.pop("_module_name") state["_qualname"] = state.pop("_module_name")
...@@ -1144,6 +1180,25 @@ class InternalGraph: ...@@ -1144,6 +1180,25 @@ class InternalGraph:
self._namespace = NameSpace(self._name, self._qualname) self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name() self._re_associate_name()
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
def _get_meth_name(obj, func): def _get_meth_name(obj, func):
tp = obj if isinstance(obj, type) else type(obj) tp = obj if isinstance(obj, type) else type(obj)
...@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func): ...@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
def _wrapped_function(orig_func): def _wrapped_function(orig_func):
@functools.wraps(orig_func) @functools.wraps(orig_func)
def wrapped_fn(*args, **kwargs): def wrapped_fn(*args, **kwargs):
method_func = wrapped_fn method_func = kwargs.pop("method_func", wrapped_fn)
if "method_func" in kwargs:
method_func = kwargs.pop("method_func")
if is_tracing_module(): if is_tracing_module():
unset_module_tracing() unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs)) inputs, tree_def = tree_flatten((args, kwargs))
...@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func): ...@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
if not NodeMixin.get(i, None): if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)): if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i)) NodeMixin.wrap_safe(i, Constant.make(i))
meth_name, arg_type = None, None args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs)
if args: meth_name = _get_meth_name(args[0], method_func)
meth_name = _get_meth_name(args[0], method_func) arg_type = args[0] if isinstance(args[0], type) else type(args[0])
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
if meth_name and arg_type and issubclass(arg_type, RawTensor): if meth_name and arg_type and issubclass(arg_type, RawTensor):
inputs, tree_def = tree_flatten((args, kwargs))
self = inputs[0] self = inputs[0]
if meth_name == "__new__": if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]): if all([not isinstance(i, RawTensor) for i in inputs]):
...@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func): ...@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
call_node = CallMethod.make(NodeMixin.get(self), meth_name) call_node = CallMethod.make(NodeMixin.get(self), meth_name)
call_node.add_inputs(inputs[1:]) call_node.add_inputs(inputs[1:])
else: else:
inputs, tree_def = tree_flatten((args, kwargs))
call_node = CallFunction.make(orig_func) call_node = CallFunction.make(orig_func)
call_node.add_inputs(inputs) call_node.add_inputs(inputs)
...@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
"_record_wrapped_nodes", "_record_wrapped_nodes",
"_argdef_graph_map", "_argdef_graph_map",
"_argdef_outdef_map", "_argdef_outdef_map",
"_check_qat_module",
"nodes", "nodes",
"__class__", "__class__",
"__dict__", "__dict__",
"_is_top",
] ]
def __init__(self, mod, is_top_module=False): def __init__(self, mod, is_top_module=False):
...@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
qat_module.weight_fake_quant.set_qparams(qparams) qat_module.weight_fake_quant.set_qparams(qparams)
def build(self): def build(self):
if self._is_builtin or isinstance(self._mod, TracedModule): if self._is_builtin:
if module_tracer.is_builtin(self._mod) or isinstance( assert module_tracer.is_builtin(self._mod)
self._mod, TracedModule mod_type = type(self._mod)
):
mod_type = type(self._mod)
else:
assert isinstance(self._mod, (Observer, _FakeQuantize))
mod_type = (
Observer if isinstance(self._mod, Observer) else _FakeQuantize
)
for node in self.nodes: for node in self.nodes:
node.module_type = mod_type node.module_type = mod_type
return self._mod return self._mod
else: else:
is_qat = isinstance(self._mod, QATModule) is_qat = isinstance(self._mod, QATModule) or (
isinstance(self._mod, TracedModule) and self._mod.is_qat
)
traced_module = TracedModule( traced_module = TracedModule(
self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
) )
...@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
traced_module.with_act = self._mod.with_act traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"): if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fakequant = None traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"): if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"): if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fakequant = None traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"): if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None traced_module.weight_observer = None
set_module_tracing() set_module_tracing()
if self._is_top:
traced_module._update_ref()
return traced_module return traced_module
def _record_wrapped_nodes(self, node): def _record_wrapped_nodes(self, node):
...@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
if "method_func" in kwargs: if "method_func" in kwargs:
kwargs.pop("method_func") kwargs.pop("method_func")
args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)
def mark_constant(x): def mark_constant(x):
node = NodeMixin.get(x, None) node = NodeMixin.get(x, None)
...@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def callnode.arg_def = tree_def
if ( if self._is_builtin or tree_def in self._argdef_graph_map:
self._is_builtin
or tree_def in self._argdef_graph_map
or isinstance(self._mod, TracedModule)
):
unset_module_tracing() unset_module_tracing()
rst = self._mod(*args, **kwargs) rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
...@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
self._body = None self._body = None
elif tree_def in self._argdef_graph_map: elif tree_def in self._argdef_graph_map:
self._body = self._argdef_graph_map[tree_def] self._body = self._argdef_graph_map[tree_def]
else:
self._mod._is_top = False
self._body = self._mod.argdef_graph_map[tree_def]
module_qualname = NodeMixin.get(self).qualname
if module_qualname != self._body.qualname:
src_name, dst_name = self._body.qualname, module_qualname
def replace_qualname(g):
attr_name = get_suffix_name(src_name, g.qualname)
if attr_name is not None:
g._qualname = (
("%s.%s" % (dst_name, attr_name))
if attr_name
else dst_name
)
assert get_suffix_name(dst_name, g.qualname) is not None
for mod in self._mod.modules():
if not hasattr(mod, "argdef_graph_map"):
continue
for g in mod.argdef_graph_map.values():
replace_qualname(g)
g._namespace.qualname = g.qualname
for n in g.nodes(False):
replace_qualname(n)
else: else:
self_node = None
orig_self = NodeMixin.get(self) orig_self = NodeMixin.get(self)
parent_graph = active_module_tracer().current_scope() parent_graph = active_module_tracer().current_scope()
module_qualname = orig_self._qualname module_qualname = orig_self._qualname
...@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
if self_node: NodeMixin.wrap_safe(
NodeMixin.wrap_safe(self, self_node) self,
active_module_tracer().current_scope()._add_input(self_node) Input.make(
else: name="self",
NodeMixin.wrap_safe( qualname=module_qualname,
self, type=NodeMixin.get_wrapped_type(self),
self_node ),
if self_node )
else Input.make(
name="self",
qualname=module_qualname,
type=NodeMixin.get_wrapped_type(self),
),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
...@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
return x return x
args = [self] args = [self]
for i, v in enumerate(inputs[1:]): orig_traced_inputs = (
args.append(wrap(v, idx2key[i + 1])) None
if not isinstance(self._mod, TracedModule)
else self._mod.argdef_graph_map[tree_def].inputs
)
ind = 1
for v in inputs[1:]:
if isinstance(v, (RawTensor, NodeMixin)):
args_name = (
orig_traced_inputs[ind]._name
if orig_traced_inputs
else idx2key[ind]
)
ind += 1
args.append(wrap(v, args_name))
else:
args.append(v)
args, kwargs = tree_def.unflatten(args) args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch( active_module_tracer().patcher.auto_patch(
...@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
attr = getattr(type(self._mod), name).__get__(self, type(self)) attr = getattr(type(self._mod), name).__get__(self, type(self))
else: else:
attr = getattr(self._mod, name) attr = getattr(self._mod, name)
if ( if (
isinstance(attr, FunctionType) isinstance(attr, FunctionType)
and id(attr) in active_module_tracer().patcher.patched_fn_ids and id(attr) in active_module_tracer().patcher.patched_fn_ids
...@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
wrapped = self.__getattr__(name) wrapped = self.__getattr__(name)
if isinstance(wrapped, TracedModuleBuilder): if isinstance(wrapped, TracedModuleBuilder):
if not isinstance(mod_attr, (List, Dict)): if not isinstance(mod_attr, (List, Dict, QATModule)):
assert mod_attr is wrapped._mod assert mod_attr is wrapped._mod
else: else:
assert mod_attr is wrapped assert mod_attr is wrapped
...@@ -1977,8 +2011,6 @@ class TracedModule(Module): ...@@ -1977,8 +2011,6 @@ class TracedModule(Module):
def graph(self) -> InternalGraph: def graph(self) -> InternalGraph:
"""Return the ``InternalGraph`` of this ``TracedModule``. """Return the ``InternalGraph`` of this ``TracedModule``.
""" """
if self._is_top:
self._update_ref()
assert len(self.argdef_graph_map) == 1 assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0] return list(self.argdef_graph_map.values())[0]
...@@ -2112,7 +2144,7 @@ class TracedModule(Module): ...@@ -2112,7 +2144,7 @@ class TracedModule(Module):
if hasattr(obj, "argdef_graph_map") if hasattr(obj, "argdef_graph_map")
else None else None
) )
if expr_graph is not None: if expr_graph is not None and not obj.is_qat:
exprs = _flatten_subgraph(graph, expr_graph, expr, obj) exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
if parent_graph is not None: if parent_graph is not None:
...@@ -2137,26 +2169,119 @@ class TracedModule(Module): ...@@ -2137,26 +2169,119 @@ class TracedModule(Module):
) )
new_module.graph._re_associate_name() new_module.graph._re_associate_name()
new_module.graph.compile() new_module.graph.compile()
new_module._update_ref()
new_module.graph._reset_ids() new_module.graph._reset_ids()
return new_module return new_module
def __getstate__(self): def __getstate__(self):
d = self.__dict__ d = self.__dict__.copy()
for k in Module.__dict__: for k in Module.__dict__:
d.pop(k, None) d.pop(k, None)
_check_obj_attr(d)
for k in d:
if module_tracer.is_builtin(d[k]):
assert _check_builtin_module_attr(
d[k]
), "Module {} can not be serialized. ".format(type(d[k]))
d[k] = _ModuleState.get_module_state(d[k])
dump_info = {
"version": __version__,
"register_type": USER_REGISTERED_LEAF_TYPE,
"register_container_type": USER_REGISTERED_CONTAINER_TYPE,
"register_mdule": USER_REGISTERED_MODULE,
"register_function": USER_REGISTERED_FUNCTION,
}
d["dump_info"] = dump_info
return d return d
def __setstate__(self, state):
    """Restore a pickled module and repair its graphs.

    Steps, in order:

    1. Every ``_ModuleState`` value in ``state`` is replaced by the module
       returned from its ``to_module()``; ``state`` then becomes this
       object's attributes and ``_update_ref()`` is called.
    2. Every expr of every graph in ``argdef_graph_map`` is passed to the
       loader matching its kind (function call, module call, tensor method
       call, apply).
    3. For every graph: exprs that produce an input of an existing expr but
       are missing from ``graph._exprs`` are inserted before their consumer;
       node ``users`` lists and output ``expr`` back-references are brought
       in line with the exprs; outputs are auto-named.
    4. ``_update_ref()`` is called again and every graph's ids are reset.

    Args:
        state: attribute dictionary produced by ``__getstate__``. It is
            modified in place.
    """
    for k, v in state.items():
        if isinstance(v, _ModuleState):
            state[k] = v.to_module()
    self.__dict__.update(state)
    self._update_ref()

    for graph in self.argdef_graph_map.values():
        for expr in graph._exprs:
            if isinstance(expr, CallFunction):
                load_functional(expr)
            if isinstance(expr, CallMethod):
                if expr.method == "__call__":
                    load_call_module_expr(expr)
                else:
                    load_call_tensor_method_expr(expr)
            if isinstance(expr, Apply):
                load_apply_expr(expr)

    for graph in self.argdef_graph_map.values():
        # Insert producer exprs that are referenced but not yet in the graph.
        # The index only advances once the current expr needs nothing new,
        # so freshly inserted exprs are examined as well.
        ind = 0
        while ind < len(graph._exprs):
            cur_expr = graph._exprs[ind]
            has_new_expr = False
            for i in cur_expr.inputs:
                if i.expr not in graph._exprs and not isinstance(i.expr, Input):
                    graph._exprs.insert(ind, i.expr)
                    has_new_expr = True
            if not has_new_expr:
                ind += 1
        for expr in graph._exprs:
            for i in expr.inputs:
                # An input node must list this expr as a user exactly as many
                # times as the expr lists the node as an input.
                add_or_del_count = expr.inputs.count(i) - i.users.count(expr)
                if add_or_del_count > 0:
                    i.users.extend([expr] * add_or_del_count)
                else:
                    # A plain loop keeps ``i`` bound to the node; the former
                    # list comprehension rebound ``i`` to an int and failed
                    # with AttributeError on ``i.users``.
                    for _ in range(-add_or_del_count):
                        i.users.remove(expr)
            for o in expr.outputs:
                if o.expr is not expr:
                    assert o not in o.expr.outputs
                    o.expr = expr
        for node in graph.nodes(False):
            # remove users of node which doesn't use node as input
            node.users = [e for e in node.users if node in e.inputs]
        for expr in graph._exprs:
            graph._namespace.auto_naming_for_outputs(expr)
    self._update_ref()
    for graph in self.argdef_graph_map.values():
        graph._reset_ids()
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
def cpp_apply_module_trace(opdef, *args): def cpp_apply_module_trace(opdef, *args):
return Apply.apply_module_trace_hook(opdef, *args) return Apply.apply_module_trace_hook(opdef, *args)
USER_REGISTERED_MODULE = []
USER_REGISTERED_FUNCTION = []
def register_as_builtin(mod_cls: Type[Module]) -> None: def register_as_builtin(mod_cls: Type[Module]) -> None:
r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
Args: Args:
mod_cls: the module class which will be treated as builtin module in tracing. mod_cls: the module class which will be treated as builtin module in tracing.
""" """
USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__))
module_tracer.register_as_builtin(mod_cls) module_tracer.register_as_builtin(mod_cls)
...@@ -2181,6 +2306,7 @@ def wrap(func: Callable): ...@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
Args: Args:
func: the function of the global function to insert into the graph when it's called. func: the function of the global function to insert into the graph when it's called.
""" """
USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__))
assert callable(func), "func must be a callable" assert callable(func), "func must be a callable"
assert hasattr(func, "__code__") assert hasattr(func, "__code__")
fn_name = func.__code__.co_name fn_name = func.__code__.co_name
...@@ -2247,6 +2373,8 @@ def trace_module( ...@@ -2247,6 +2373,8 @@ def trace_module(
NodeMixin.wrap_safe( NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
) )
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support " # assert isinstance(i, Tensor), "not support "
......
...@@ -5,12 +5,17 @@ ...@@ -5,12 +5,17 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy import copy
import inspect
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence from typing import Dict, Iterable, List, Optional, Sequence, Type
from .. import get_logger
from ..module import Module from ..module import Module
logger = get_logger(__name__)
def replace_container_with_module_container(container): def replace_container_with_module_container(container):
has_module = False has_module = False
...@@ -52,6 +57,101 @@ def replace_container_with_module_container(container): ...@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
return has_module, module_container return has_module, module_container
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func)
arg_specs_args = arg_specs.args
if is_bounded:
arg_specs_args = arg_specs.args[1:]
new_args = []
new_kwargs = {}
new_args.extend(args)
if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()):
repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
raise TypeError(
"{} got multiple values for argument {}".format(
func.__qualname__, ", ".join(repeated_arg_name)
)
)
if len(new_args) < len(arg_specs.args):
for ind in range(len(new_args), len(arg_specs_args)):
arg_name = arg_specs_args[ind]
if arg_name in kwargs:
new_args.append(kwargs[arg_name])
else:
index = ind - len(arg_specs_args) + len(arg_specs.defaults)
assert index < len(arg_specs.defaults) and index >= 0
new_args.append(arg_specs.defaults[index])
for kwarg_name in arg_specs.kwonlyargs:
if kwarg_name in kwargs:
new_kwargs[kwarg_name] = kwargs[kwarg_name]
else:
assert kwarg_name in arg_specs.kwonlydefaults
new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name]
for k, v in kwargs.items():
if k not in arg_specs.args and k not in arg_specs.kwonlyargs:
if arg_specs.varkw is None:
raise TypeError(
"{} got an unexpected keyword argument {}".format(
func.__qualname__, k
)
)
new_kwargs[k] = v
return tuple(new_args), new_kwargs
def _check_obj_attr(obj):
    """Assert that every leaf value stored in mapping ``obj`` is serializable.

    Each value is flattened with ``tree_flatten``. A leaf is accepted when
    its type (or the leaf itself, if it is a class) is a subclass of one of
    ``SUPPORTED_LEAF_CLS`` or of the traced-module types listed below, or is
    contained in ``SUPPORTED_LEAF_TYPE``. Any other leaf fails an assertion
    naming the offending type.
    """
    # Imported here rather than at module level; kept in the original order.
    from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef, tree_flatten
    from .expr import Expr
    from .traced_module import TracedModule, InternalGraph, NameSpace

    def _type_of(leaf):
        return leaf if isinstance(leaf, type) else type(leaf)

    def _is_supported(leaf):
        kind = _type_of(leaf)
        own_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
        if issubclass(kind, tuple(SUPPORTED_LEAF_CLS + own_types)):
            return True
        return kind in SUPPORTED_LEAF_TYPE

    for _, value in obj.items():
        flattened, _ = tree_flatten(value, is_leaf=lambda _: True)
        for leaf in flattened:
            assert _is_supported(
                leaf
            ), "Type {} is not supported by traced module".format(_type_of(leaf))
def _check_builtin_module_attr(mod):
    """Return whether every attribute of builtin module ``mod`` is serializable.

    Attributes that are ``Module`` instances are checked recursively. Every
    other attribute is flattened with ``tree_flatten`` and each leaf must
    pass ``pytree._is_leaf`` and must not itself be a non-serializable
    module. The ``_m_dump_modulestate`` attribute is ignored.

    Returns:
        True if all attributes pass; False at the first failure. A warning
        naming the offending type is logged when a non-module leaf fails.
    """
    from .pytree import _is_leaf as _check_leaf_type
    from .pytree import tree_flatten

    def is_non_serializable_module(m):
        return isinstance(m, Module) and not _check_builtin_module_attr(m)

    for k, v in mod.__dict__.items():
        if k == "_m_dump_modulestate":
            continue
        if is_non_serializable_module(v):
            return False
        elif not isinstance(v, Module):
            leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
            for leaf in leafs:
                if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(
                        "Type {} is not supported by traced module".format(
                            leaf if isinstance(leaf, type) else type(leaf)
                        )
                    )
                    return False
    return True
class _ModuleList(Module, MutableSequence): class _ModuleList(Module, MutableSequence):
r"""A List-like container. r"""A List-like container.
......
...@@ -15,7 +15,6 @@ import numpy as np ...@@ -15,7 +15,6 @@ import numpy as np
import megengine as mge import megengine as mge
from megengine import Parameter, Tensor from megengine import Parameter, Tensor
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state
def test_tensor_serialization(): def test_tensor_serialization():
...@@ -88,25 +87,3 @@ def test_compatibility(): ...@@ -88,25 +87,3 @@ def test_compatibility():
test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge") test_old_tensor("tensor_v1_2.mge")
def test_opdef_serialization():
    """Check that opdefs survive a pickle round trip through their state."""

    def _roundtrip(opdef):
        # Dump the opdef state to a temporary file and load it back.
        with TemporaryFile() as f:
            pickle.dump(get_opdef_state(opdef), f)
            f.seek(0)
            return load_opdef_from_state(pickle.load(f))

    elemwise = builtin.Elemwise(mode="Add")
    assert elemwise == _roundtrip(elemwise)

    conv = builtin.Convolution(stride_h=9, compute_mode="float32")
    conv.strategy = (
        builtin.Convolution.Strategy.PROFILE
        | builtin.Convolution.Strategy.HEURISTIC
        | builtin.Convolution.Strategy.REPRODUCIBLE
    )
    loaded_conv = _roundtrip(conv)
    assert conv.strategy == loaded_conv.strategy
    assert conv == loaded_conv
...@@ -85,12 +85,12 @@ class NewModule(M.Module): ...@@ -85,12 +85,12 @@ class NewModule(M.Module):
return x return x
def _check_expr_users(traced_module): def _check_expr_users(flattened_module):
node_user = defaultdict(list) node_user = defaultdict(list)
for expr in traced_module.graph._exprs: for expr in flattened_module.graph._exprs:
for node in expr.inputs: for node in expr.inputs:
node_user[node].append(expr) node_user[node].append(expr)
for node in traced_module.graph.nodes(): for node in flattened_module.graph.nodes():
node.users.sort(key=lambda m: m._id) node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id) node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node] assert node.users == node_user[node]
......
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.module.qat as QM
import megengine.quantization as Q import megengine.quantization as Q
from megengine import Tensor from megengine import Tensor
from megengine.module.qat.module import QATModule from megengine.module.qat.module import QATModule
...@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str): ...@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
return getattr(self, name) return getattr(self, name)
class MyConvBnRelu2d(M.ConvBnRelu2d):
    """Subclass of ``M.ConvBnRelu2d`` that adds nothing of its own."""
class MyQATConvBnRelu2d(QM.ConvBnRelu2d):
    """Subclass of ``QM.ConvBnRelu2d`` that adds nothing of its own."""
class Myblcok(M.Module): class Myblcok(M.Module):
def __init__(self,): def __init__(self,):
super().__init__() super().__init__()
self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1)
self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
self.add = M.Elemwise("FUSE_ADD_RELU") self.add = M.Elemwise("FUSE_ADD_RELU")
...@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): ...@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def build_observered_net(net: M.Module, observer_cls): def build_observered_net(net: M.Module, observer_cls):
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) qat_net = Q.quantize_qat(
net,
qconfig=get_observer_config(observer_cls),
mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
)
Q.enable_observer(qat_net) Q.enable_observer(qat_net)
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
qat_net(inp) qat_net(inp)
...@@ -134,6 +147,15 @@ def test_trace_qat(): ...@@ -134,6 +147,15 @@ def test_trace_qat():
check_qparams(weight_qparams, traced_weight_qparams) check_qparams(weight_qparams, traced_weight_qparams)
if act_qparams: if act_qparams:
check_qparams(act_qparams, traced_act_qparams) check_qparams(act_qparams, traced_act_qparams)
flatten_traced_net = traced_net.flatten()
conv0_node = flatten_traced_net.graph.get_node_by_name(
"MyModule_block0_conv0"
).as_unique()
conv0_out_node = flatten_traced_net.graph.get_node_by_name(
"MyModule_block0_conv0_out"
).as_unique()
assert isinstance(conv0_node.owner, TracedModule)
assert conv0_out_node.expr.inputs[0] is conv0_node
_check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver))
_check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver))
......
...@@ -6,14 +6,59 @@ ...@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle import pickle
from collections import defaultdict
from tempfile import TemporaryFile
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.traced_module.serialization as S
from megengine import Tensor from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise
from megengine.module import Module from megengine.module import Module
from megengine.traced_module import trace_module from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallMethod, Constant
from megengine.traced_module.node import TensorNode
from megengine.traced_module.serialization import (
register_functional_loader,
register_module_loader,
register_opdef_loader,
register_tensor_method_loader,
)
from megengine.traced_module.utils import _convert_kwargs_to_args
def _check_id(traced_module):
    """Assert that node ids and expr ids in the module's graph are consistent.

    For the nodes and, separately, for the exprs of ``traced_module.graph``
    this checks that no ``_id`` is repeated and that the largest ``_id`` plus
    one equals the matching entry of ``graph._total_ids`` (index 0 for nodes,
    index 1 for exprs).
    """
    graph = traced_module.graph
    counters = graph._total_ids

    def _assert_unique_and_bounded(items, next_id):
        ids = [item._id for item in items]
        assert len(ids) == len(set(ids))
        assert next_id == max(ids) + 1

    _assert_unique_and_bounded(graph.nodes().as_list(), counters[0])
    _assert_unique_and_bounded(graph.exprs().as_list(), counters[1])
def _check_name(flatened_module):
    """Assert that no two nodes of the module's graph share the same ``_name``."""
    nodes = flatened_module.graph.nodes().as_list()
    distinct_names = {node._name for node in nodes}
    assert len(distinct_names) == len(nodes)
def _check_expr_users(traced_module):
    """Assert that each node's ``users`` list matches the exprs that consume it.

    The expected consumers are rebuilt from the ``inputs`` of every expr in
    the module's graph.  Whenever an expr is a ``CallMethod`` with a truthy
    ``graph``, the owner of its first input is checked recursively as well.
    """
    graph = traced_module.graph

    expected_users = defaultdict(list)
    for expr in graph._exprs:
        for input_node in expr.inputs:
            expected_users[input_node].append(expr)
        if isinstance(expr, CallMethod) and expr.graph:
            # The first input of such a call is the node whose owner is checked.
            _check_expr_users(expr.inputs[0].owner)

    def _by_id(item):
        return item._id

    # NOTE(review): ``nodes(False)`` presumably limits the walk to this graph
    # only (no descent into sub-graphs) -- confirm against the graph API.
    for node in graph.nodes(False):
        node.users.sort(key=_by_id)
        expected_users[node].sort(key=_by_id)
        assert node.users == expected_users[node]
class MyBlock(Module): class MyBlock(Module):
...@@ -48,5 +93,274 @@ def test_dump_and_load(): ...@@ -48,5 +93,274 @@ def test_dump_and_load():
traced_module = trace_module(module, x) traced_module = trace_module(module, x)
np.testing.assert_array_equal(expect, traced_module(x)) np.testing.assert_array_equal(expect, traced_module(x))
obj = pickle.dumps(traced_module) obj = pickle.dumps(traced_module)
pickle.loads(obj) new_tm = pickle.loads(obj)
_check_id(new_tm)
_check_expr_users(new_tm)
traced_module.graph._reset_ids()
old_nodes = traced_module.graph.nodes().as_list()
new_nodes = new_tm.graph.nodes().as_list()
old_exprs = traced_module.graph.exprs().as_list()
new_exprs = new_tm.graph.exprs().as_list()
assert len(old_nodes) == len(new_nodes)
for i, j in zip(old_nodes, new_nodes):
assert i._name == j._name
assert i._qualname == j._qualname
assert i._id == j._id
assert len(old_exprs) == len(new_exprs)
for i, j in zip(old_exprs, new_exprs):
assert i._id == j._id
np.testing.assert_array_equal(expect, traced_module(x)) np.testing.assert_array_equal(expect, traced_module(x))
def test_opdef_loader():
    """Check that a registered opdef loader can rewrite an ``Elemwise`` expr on load.

    A module applying ``Elemwise("ADD")`` is traced, pickled and unpickled
    while ``S.OPDEF_LOADER`` is temporarily replaced by an empty registry that
    only receives the loader defined below.  The original registry is restored
    in a ``finally`` clause so that a failing assertion cannot leak the test
    loader into other tests.
    """

    class MyModule1(Module):
        def forward(self, x, y):
            op = Elemwise("ADD")
            return apply(op, x, y)[0]

    m = MyModule1()
    x = Tensor(np.ones((20)))
    y = Tensor(np.ones((20)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.OPDEF_LOADER
    S.OPDEF_LOADER = {}
    try:

        @register_opdef_loader(Elemwise)
        def add_opdef_loader(expr):
            # Turn ADD into MUL and route the second input through an
            # inserted ``astype`` call that uses the first input's dtype.
            if expr.opdef_state["mode"] == "ADD":
                expr.opdef_state["mode"] = "MUL"
                node = expr.inputs[1]
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr,
                    shape=node.shape,
                    dtype=expr.inputs[0].dtype,
                    qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup,)
                expr.inputs[1] = oup

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_id(new_module)
        _check_expr_users(new_module)
        _check_name(new_module.flatten())
        # The loaded graph must hold the inserted astype call followed by the
        # rewritten (MUL) elemwise op, and nothing else.
        assert (
            isinstance(new_module.graph._exprs[0], CallMethod)
            and new_module.graph._exprs[1].opdef.mode == "MUL"
            and len(new_module.graph._exprs) == 2
        )
        result = new_module(x, y)
        # Both inputs are all ones, so a product leaves x unchanged while the
        # original sum would not.
        np.testing.assert_equal(result.numpy(), x.numpy())
    finally:
        # Restore the global registry even when an assertion above fails.
        S.OPDEF_LOADER = orig_loader_dict
def test_functional_loader():
    """Check that a registered functional loader can rewrite an ``F.conv2d`` call on load.

    A module calling ``F.conv2d`` is traced, pickled and unpickled while
    ``S.FUNCTIONAL_LOADER`` is temporarily replaced by an empty registry that
    only receives the loader defined below.  The original registry is restored
    in a ``finally`` clause so that a failing assertion cannot leak the test
    loader into other tests.
    """

    class MyModule2(Module):
        def forward(self, x, y):
            return F.conv2d(x, y)

    m = MyModule2()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    y = Tensor(np.random.random((3, 3, 3, 3)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.FUNCTIONAL_LOADER
    S.FUNCTIONAL_LOADER = {}
    try:

        @register_functional_loader(("megengine.functional.nn", "conv2d"))
        def conv2df_loader(expr):
            # expr.func = ("megengine.functional.nn","conv2d")
            # Replace the ``weight`` argument with the output of an inserted
            # ``astype`` call that uses the dtype of the ``inp`` argument.
            orig_weight = expr.named_args["weight"]
            astype_expr = CallMethod(orig_weight, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_weight.shape,
                dtype=orig_weight.dtype,
                qparams=orig_weight.qparams,
            )
            astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
            astype_expr.return_val = (oup,)
            expr.set_arg("weight", oup)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x, y)
        gt = m(x, y)
        # The inserted astype call must come first, followed only by conv2d.
        assert (
            isinstance(new_module.graph._exprs[0], CallMethod)
            and len(new_module.graph._exprs) == 2
        )
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        # Restore the global registry even when an assertion above fails.
        S.FUNCTIONAL_LOADER = orig_loader_dict
def test_tensor_method_loader():
    """Check that a registered tensor-method loader can rewrite ``__add__`` on load.

    A module computing ``x + 1`` is traced, pickled and unpickled while
    ``S.TENSORMETHOD_LOADER`` is temporarily replaced by an empty registry
    that only receives the loader defined below.  The original registry is
    restored in a ``finally`` clause so that a failing assertion cannot leak
    the test loader into other tests.
    """

    class MyModule3(Module):
        def forward(self, x):
            return x + 1

    m = MyModule3()
    x = Tensor(np.ones((20)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.TENSORMETHOD_LOADER
    S.TENSORMETHOD_LOADER = {}
    try:

        @register_tensor_method_loader("__add__")
        def add_loader(expr):
            # Rewrite ``x + c`` into ``x + (c' + c')`` where ``c'`` is the
            # constant ``c`` passed through ``astype`` with x's dtype.
            args = list(expr.args)
            if not isinstance(args[1], TensorNode):
                args[1] = Tensor(args[1])
                node = Constant(args[1], "const").outputs[0]
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr,
                    shape=node.shape,
                    dtype=node.dtype,
                    qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup,)
                add_expr = CallMethod(oup, "__add__")
                add_expr.set_args_kwargs(oup, oup)
                oup1 = TensorNode(
                    add_expr,
                    shape=oup.shape,
                    dtype=oup.dtype,
                    qparams=node.qparams,
                )
                add_expr.return_val = oup1
                args[1] = oup1
                expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x)
        # Expected exprs: the constant, its astype, the doubling add and the
        # original add.
        assert (
            isinstance(new_module.graph._exprs[0], Constant)
            and len(new_module.graph._exprs) == 4
        )
        # The loader doubled the constant, so the loaded module adds 2, not 1.
        np.testing.assert_equal(result.numpy(), (x + 2).numpy())
    finally:
        # Restore the global registry even when an assertion above fails.
        S.TENSORMETHOD_LOADER = orig_loader_dict
def test_module_loader():
    """Check that a registered module loader can rewrite a ``Conv2d`` call on load.

    A module wrapping ``M.Conv2d`` is traced, pickled and unpickled while
    ``S.MODULE_LOADER`` is temporarily replaced by an empty registry that only
    receives the loader defined below.  The original registry is restored in a
    ``finally`` clause so that a failing assertion cannot leak the test loader
    into other tests.
    """

    class MyModule4(Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    m = MyModule4()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.MODULE_LOADER
    S.MODULE_LOADER = {}
    try:

        @register_module_loader(("megengine.module.conv", "Conv2d"))
        def conv2dm_loader(expr):
            # Route the call's second argument through an inserted ``astype``
            # call that uses the dtype of the module's weight.
            module = expr.inputs[0].owner
            args = list(expr.args)
            orig_inp = args[1]
            astype_expr = CallMethod(orig_inp, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_inp.shape,
                dtype=orig_inp.dtype,
                qparams=orig_inp.qparams,
            )
            astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
            astype_expr.return_val = (oup,)
            args[1] = oup
            expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        result = new_module(x)
        gt = m(x)
        # The inserted astype call is expected at index 1, with three exprs
        # in total.
        assert (
            isinstance(new_module.graph._exprs[1], CallMethod)
            and len(new_module.graph._exprs) == 3
        )
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        # Restore the global registry even when an assertion above fails.
        S.MODULE_LOADER = orig_loader_dict
def test_shared_module():
    """Check that one submodule bound to two attributes stays shared after a pickle round trip."""

    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            self.a = M.Elemwise("ADD")
            # ``b`` is deliberately the very same object as ``a``.
            self.b = self.a

        def forward(self, x, y):
            z = self.a(x, y)
            z = self.b(z, y)
            return z

    lhs = Tensor(1)
    rhs = Tensor(2)
    traced = trace_module(MyModule(), lhs, rhs)

    load_tm = pickle.loads(pickle.dumps(traced))

    _check_expr_users(load_tm)
    _check_name(load_tm.flatten())
    _check_id(load_tm)
    # Identity, not mere equality: the sharing itself must survive.
    assert load_tm.a is load_tm.b
def test_convert_kwargs_to_args():
    """Check the (args, kwargs) pairs returned by ``_convert_kwargs_to_args``.

    Three calls are covered: a plain call against a signature with defaults
    and keyword-only parameters, the same signature with ``is_bounded=True``,
    and a call in which every positional parameter is passed by keyword.
    """

    def func(a, b, c=4, *, d, e=3, f=4):
        pass

    def func1(a, b, c, d, e, *, f):
        pass

    # Positional parameters passed by keyword or left to their default end up
    # in args; keyword-only parameters end up in kwargs with their defaults.
    positional, keyword = _convert_kwargs_to_args(func, (1,), {"b": 1, "d": 6})
    assert positional == (1, 1, 4)
    assert keyword == {"d": 6, "e": 3, "f": 4}

    # Same signature, called with ``is_bounded=True``.
    positional, keyword = _convert_kwargs_to_args(
        func, (1,), {"d": 6}, is_bounded=True
    )
    assert positional == (1, 4)
    assert keyword == {"d": 6, "e": 3, "f": 4}

    # Everything supplied by keyword, no defaults involved.
    everything_by_keyword = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
    positional, keyword = _convert_kwargs_to_args(func1, (), everything_by_keyword)
    assert positional == (1, 2, 3, 4, 5)
    assert keyword == {"f": 6}
def test_opdef_serialization():
    """Check that builtin opdefs compare equal to themselves after a pickle round trip via a file."""

    def _roundtrip(opdef):
        # Dump to a temporary file, rewind, and load back.
        with TemporaryFile() as stream:
            pickle.dump(opdef, stream)
            stream.seek(0)
            return pickle.load(stream)

    elemwise = builtin.Elemwise(mode="Add")
    assert elemwise == _roundtrip(elemwise)

    conv = builtin.Convolution(stride_h=9, compute_mode="float32")
    strategy = builtin.Convolution.Strategy
    # A combined flag value, to check that the strategy is preserved too.
    conv.strategy = strategy.PROFILE | strategy.HEURISTIC | strategy.REPRODUCIBLE
    restored_conv = _roundtrip(conv)
    assert conv.strategy == restored_conv.strategy
    assert conv == restored_conv
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册