Commit b3d0affa authored by Megvii Engine Team

feat(traced_module): support trace custom qat module

GitOrigin-RevId: 49f70a5f467e93ff58fc5152499f04733258fd0d
Parent 15712807
@@ -17,12 +17,14 @@ from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.builtin import FakeQuant
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, tree_flatten
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import get_opdef_state, load_opdef_from_state
def rstrip(s: str, __chars: str):
@@ -76,6 +78,7 @@ class Expr:
node.users.append(self)
else:
assert node is None
assert _is_leaf(val) and _is_const_leaf(val)
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
@@ -154,6 +157,11 @@ class Expr:
return self._top_graph()
return None
def __getstate__(self):
state = self.__dict__.copy()
state.pop("_top_graph", None)
return state
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
@@ -321,14 +329,36 @@ class Apply(Expr):
", ".join(str(i) for i in self.inputs),
)
def __getstate__(self):
state = super().__getstate__()
state["opdef"] = get_opdef_state(state["opdef"])
return state
def __setstate__(self, state):
state["opdef"] = load_opdef_from_state(state["opdef"])
for k, v in state.items():
setattr(self, k, v)
@classmethod
def apply_module_trace_hook(cls, opdef, *inputs):
for i in inputs:
node = NodeMixin.get(i, None)
if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
apply_node = cls.make(opdef)
apply_node.add_inputs(inputs)
if isinstance(opdef, FakeQuant):
inp_nodes = [NodeMixin.get(inputs[0])]
for i in inputs[1:]:
node = Constant.make(i)
inp_nodes.append(node)
apply_node = cls.make(opdef)
for n in inp_nodes:
n.users.append(apply_node)
apply_node.inputs = inp_nodes
else:
apply_node = cls.make(opdef)
apply_node.add_inputs(inputs)
assert not apply_node.const_val
unset_module_tracing()
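
The FakeQuant branch above records only the first input (the tensor being quantized) as a traced node; the trailing inputs (scale and zero_point) are wrapped as Constant nodes so they are baked into the graph instead of surfacing as graph inputs. A minimal standalone sketch of that dispatch, using hypothetical helper names rather than MegEngine API:

def record_apply_inputs(opdef_name, inputs, get_node, make_constant):
    # FakeQuant: trace only the data tensor, freeze scale/zero_point
    if opdef_name == "FakeQuant":
        return [get_node(inputs[0])] + [make_constant(i) for i in inputs[1:]]
    # default: every input becomes a traced node
    return [get_node(i) for i in inputs]
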
@@ -387,7 +417,7 @@ class Constant(Expr):
super().__init__()
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
assert module_tracer.is_builtin(c) or c.is_qat
self.value = c
self.name = name
self.inputs = []
@@ -395,6 +425,7 @@
self.outputs = [
node_cls(self, name=name),
]
self.outputs[0]._name = name if name else "const_" + str(self._id)
@classmethod
def make(cls, *args, **kwargs):
@@ -422,7 +453,7 @@
)
def __getstate__(self):
state = self.__dict__.copy()
state = super().__getstate__()
if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value)
return state
from copy import deepcopy
from typing import Union
from ...core.tensor.dtype import QuantDtypeMeta
from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ...quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
def __init__(
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
super().__init__(dtype, enable, **kwargs)
self.qparams = None
def fake_quant_forward(self, inp, qparams: QParams = None):
if qparams is None:
qparams = self.get_qparams()
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
return fake_quant_tensor(inp, qparams)
def get_qparams(self):
return self.qparams
def set_qparams(self, qparams: QParams):
"""
:param qparams: used to set initial scale.
"""
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
scale = qparams.scale
if qparams.dtype_meta is None:
qparams.dtype_meta = self.dtype
else:
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
dtype_meta = qparams.dtype_meta
zero_point = qparams.zero_point
mode = qparams.mode
self.qparams = QParams(mode, dtype_meta, scale, zero_point)
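
For reference, a hedged usage sketch of the FakeQuantize wrapper defined above: initialize it from observed statistics, then call it like any Module. The import paths, the qint8 entry of _builtin_quant_dtypes, and the QuantMode.SYMMERT spelling follow MegEngine 1.x and are assumptions here:

import numpy as np
import megengine as mge
from megengine.core.tensor.dtype import _builtin_quant_dtypes
from megengine.quantization.utils import QParams, QuantMode

dtype_meta = _builtin_quant_dtypes["qint8"]
fq = FakeQuantize(dtype=dtype_meta)
# freeze the observer's statistics into the module (mode, dtype, scale, zero_point)
fq.set_qparams(QParams(QuantMode.SYMMERT, dtype_meta, mge.tensor(0.1), None))
x = mge.tensor(np.random.randn(4).astype("float32"))
y = fq(x)  # rounds x onto the qint8 grid and back to float32
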
@@ -12,6 +12,7 @@ from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...module.qat import QATModule
_active_module_tracer = None
@@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [
"__iand__",
"__ior__",
"__ixor__",
"T",
"transpose",
"astype",
"reshape",
"_broadcast",
@@ -180,6 +181,7 @@ class Patcher:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
self.patch_method(Tensor, "detach", self.wrap_fn)
self.patch_method(Tensor, "__new__", self.wrap_fn)
self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)
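
Patching QATModule._apply_fakequant_with_observer routes every fake-quant call inside a QAT forward through the tracer's wrapper, just like the patched Tensor methods above. A standalone sketch of the monkey-patching idea (not the actual Patcher API, which also records originals so they can be restored when tracing ends):

def patch_method(cls, name, wrap_fn):
    # swap cls.name for a wrapped version that notifies the tracer
    orig = getattr(cls, name)
    setattr(cls, name, wrap_fn(orig))
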
@@ -127,7 +127,7 @@ class TensorNode(Node):
shape = None # type: Tuple[int]
dtype = None # type: numpy.dtype
qparam = None
qparams = None
device = None
def __getstate__(self):
@@ -135,7 +135,7 @@
"expr": self.expr,
"users": self.users,
"_id": self._id,
"qparam": self.qparam,
"qparams": self.qparams,
"shape": self.shape,
"dtype": self.dtype,
"device": self.device,
@@ -155,10 +155,7 @@ def tree_flatten(
assert is_leaf(values), values
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
if isinstance(values, np.ndarray):
node.const_val = str(values)
else:
node.const_val = values
node.const_val = values
return [values,], node
rst = []
@@ -232,9 +229,13 @@ class LeafDef(TreeDef):
return leaves[0]
def __eq__(self, other):
if isinstance(self.const_val, np.ndarray):
return self.type == other.type and (self.const_val == other.const_val).all()
return self.type == other.type and self.const_val == other.const_val
def __hash__(self):
if isinstance(self.const_val, np.ndarray):
return hash(tuple([self.type, str(self.const_val)]))
return hash(tuple([self.type, self.const_val]))
def __repr__(self):
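
The np.ndarray branches above are needed because elementwise == on arrays yields an array rather than a bool, and arrays are unhashable. A small illustration:

import numpy as np

a, b = np.array([1, 2]), np.array([1, 2])
assert (a == b).all()             # plain `a == b` is an array, not a bool
key = hash((np.ndarray, str(a)))  # hash(a) raises TypeError; str(a) is a stable stand-in
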
@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import (
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
from ...module.qat import QATModule
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ...quantization.observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
Observer,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
from ...tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
Patcher,
active_module_tracer,
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
if meth_name:
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
if meth_name and issubclass(arg_type, RawTensor):
self = inputs[0]
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod
self._body = None
self._is_top = is_top_module
self._is_builtin = module_tracer.is_builtin(mod)
self._is_builtin = (
True
if isinstance(mod, (Observer, _FakeQuantize))
else module_tracer.is_builtin(mod)
)
if isinstance(self._mod, QATModule):
unset_module_tracing()
self._check_qat_module(self._mod)
set_module_tracing()
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
@@ -693,15 +708,65 @@
dict(TracedModuleBuilder.__dict__),
)
def _check_qat_module(self, qat_module):
def isbuiltin(m):
return m is None or module_tracer.is_builtin(m)
if qat_module.with_act:
act_observer = qat_module.act_observer
act_fakequant = qat_module.act_fake_quant
if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
qparams = (
act_observer.get_qparams()
if hasattr(act_observer, "get_qparams")
else act_fakequant.get_qparams()
)
dtype = (
act_observer.dtype
if hasattr(act_observer, "dtype")
else act_fakequant.dtype
)
qat_module.act_observer = None
qat_module.act_fake_quant = TM_FakeQuant(dtype)
qat_module.act_fake_quant.set_qparams(qparams)
if qat_module.with_weight:
weight_observer = qat_module.weight_observer
weight_fakequant = qat_module.weight_fake_quant
if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
qparams = (
weight_observer.get_qparams()
if hasattr(weight_observer, "get_qparams")
else weight_fakequant.get_qparams()
)
dtype = (
weight_observer.dtype
if hasattr(weight_observer, "dtype")
else weight_fakequant.dtype
)
qat_module.weight_observer = None
qat_module.weight_fake_quant = TM_FakeQuant(dtype)
qat_module.weight_fake_quant.set_qparams(qparams)
def build(self):
if self._is_builtin or isinstance(self._mod, TracedModule):
if module_tracer.is_builtin(self._mod) or isinstance(
self._mod, TracedModule
):
mod_type = type(self._mod)
else:
assert isinstance(self._mod, (Observer, _FakeQuantize))
mod_type = (
Observer if isinstance(self._mod, Observer) else _FakeQuantize
)
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
node.module_type = mod_type
return self._mod
else:
is_qat = isinstance(self._mod, QATModule)
traced_module = TracedModule(
self._is_top, self._argdef_graph_map, self._argdef_outdef_map
self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
)
for _, g in self._argdef_graph_map.items():
g.compile()
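
_check_qat_module above freezes any non-builtin observer/fake-quant pair into the serializable traced_module FakeQuantize before tracing, which is what lets a custom QAT module be traced at all. A hedged end-to-end sketch of the resulting flow; the module and helper names follow MegEngine 1.x public API (the default qconfig is used here, while a qconfig with user-defined observers is what exercises the new path):

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization import quantize_qat
from megengine.traced_module import trace_module

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.ConvBn2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

net = quantize_qat(Net())    # float modules -> QATModules with observers
inp = mge.tensor(np.random.randn(1, 3, 32, 32).astype("float32"))
tm = trace_module(net, inp)  # fake-quant/observer calls are captured in the graph
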
@@ -712,6 +777,20 @@
v = v.build()
setattr(traced_module, k, v)
if isinstance(self._mod, QATModule):
unset_module_tracing()
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None
set_module_tracing()
return traced_module
def _record_wrapped_nodes(self, node):
@@ -846,7 +925,8 @@
attr = getattr(self._mod, name)
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(self, name, attr)
if isinstance(attr, (Module, RawTensor)):
setattr(self, name, attr)
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
@@ -1066,7 +1146,7 @@ class TracedModule(Module):
argdef_graph_map = None
argdef_outdef_map = None
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
super(TracedModule, self).__init__()
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map
@@ -1074,6 +1154,7 @@
self.watch_points = []
self.watch_node_value = {}
self.end_points = []
self.is_qat = is_qat
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs))
@@ -1195,8 +1276,8 @@
):
if graph is not None and prefix_name and prefix_name[-1] != "_":
prefix_name += "_"
if graph is None:
assert not isinstance(module, TracedModule)
if graph is None or module.is_qat:
assert not isinstance(module, TracedModule) or module.is_qat
const = Constant(module, "self.%s" % module2name[id(module)])
m_node = call.inputs[0]
if m_node.top_graph != active_module_tracer().current_scope():
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module():
isclass(m[1])
and issubclass(m[1], M.Module)
and m[1] is not M.Sequential
and m[1] is not M.ModuleList
):
module_tracer.register_as_builtin(m[1])
module_tracer.register_as_builtin(Observer)
module_tracer.register_as_builtin(MinMaxObserver)
module_tracer.register_as_builtin(SyncMinMaxObserver)
module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
module_tracer.register_as_builtin(HistogramObserver)
module_tracer.register_as_builtin(PassiveObserver)
module_tracer.register_as_builtin(LSQ)
module_tracer.register_as_builtin(TQT)
module_tracer.register_as_builtin(FakeQuantize)
module_tracer.register_as_builtin(TM_FakeQuant)
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
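
The registrations above make every builtin observer and fake-quant class a tracing leaf: the tracer emits a single call on a Constant module node instead of inlining their forward. User code can opt its own modules into the same treatment; a hedged sketch, assuming register_as_builtin is re-exported from megengine.traced_module:

import megengine.module as M
from megengine.traced_module import register_as_builtin

class MyLeaf(M.Module):
    def forward(self, x):
        return x * 2

register_as_builtin(MyLeaf)  # MyLeaf.forward is no longer traced into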