Commit 7c4f1a38 authored by Megvii Engine Team

feat(mge/quantization): add calibration support

GitOrigin-RevId: f16fbba2b7cbc6138c4382fcb96b70f5eb71074c
Parent 5eca4da3
@@ -496,6 +496,9 @@ class QATModule(Module):
         self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
     ):
         oup = self.apply_observer(target, obs)
-        scale, zero_point = obs.get_qparams()
-        return fq(oup, scale, zero_point)
+        if self.quantizing == self.QATMode.CALIBRATION:
+            return oup
+        else:
+            scale, zero_point = obs.get_qparams()
+            return fq(oup, scale, zero_point)
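Editor's note: in CALIBRATION mode the observer output is returned as-is, so only range statistics are collected; otherwise fake quantization is applied with the observer's qparams. For intuition, fake quantization simulates a round-trip through the quantized grid while staying in float. A minimal numpy sketch of that computation (an illustration only, not MegEngine's FakeQuantize; the qint8 range is assumed):

import numpy as np

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Round onto the integer grid and clamp to the representable range...
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    # ...then map back to float so downstream code sees the quantization error.
    return (q - zero_point) * scale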
@@ -524,11 +527,7 @@ class QATModule(Module):
     """

     def __call__(self, *args, **kwargs):
-        if self.quantizing == self.QATMode.QAT:
-            return self.forward_qat(*args, **kwargs)
-        elif self.quantizing == self.QATMode.CALIBRATION:
-            # TODO implement the CALIBRATION
-            assert False
-            return None
-        else:
+        if self.quantizing == self.QATMode.DISABLED:
             return self.forward(*args, **kwargs)
+        else:
+            return self.forward_qat(*args, **kwargs)
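Editor's note: the rewrite collapses the dispatch to a single question, is quantization disabled? Both QAT and CALIBRATION share forward_qat and only diverge later, inside apply_fakequant_with_observer above. A self-contained sketch of the resulting control flow (an illustrative stand-in with the mode names from this commit, not the MegEngine class):

from enum import Enum

class QATMode(Enum):
    DISABLED = 0
    QAT = 1
    CALIBRATION = 2

def dispatch(mode, forward, forward_qat, *args):
    # Anything but DISABLED takes the QAT path; calibration is handled
    # further down, where fake quantization would otherwise be applied.
    if mode == QATMode.DISABLED:
        return forward(*args)
    return forward_qat(*args)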
@@ -20,11 +20,9 @@ class Concat(Module):
     A :class:`~.Module` to do quantized concat, inference only.
     """

-    def __init__(self):
+    def __init__(self, dtype=None):
         super().__init__()
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
         if self.training:
@@ -39,7 +37,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = Concat()
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return Concat(float_module.act_observer.get_dtype())
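Editor's note: the ``register_method_to_class`` pattern used by these ``to_quantized`` functions attaches the method to the float class from the quantized module's file, so the float module never imports its quantized counterpart. A minimal sketch of what such a helper can look like (a simplified hypothetical, not MegEngine's actual implementation):

def register_method_to_class(cls):
    # Decorator factory: installs the decorated function on `cls` under
    # its own name, so `float_module.to_quantized()` resolves to it.
    def decorator(fn):
        setattr(cls, fn.__name__, fn)
        return fn
    return decorator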
@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d):
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        dtype=None,
     ):
         super().__init__(
             in_channels,
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d):
             conv_mode,
             compute_mode,
         )
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
-        self.weight = self.weight.astype(self.output_dtype)
-        self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
+        self.output_dtype = dtype

     def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
         inp_scale = mgb.dtype.get_scale(inp.dtype)
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d):
 def to_quantized(quantized_class, float_module):
+    output_dtype = float_module.act_observer.get_dtype()
     qconv = quantized_class(
         float_module.conv.in_channels,
         float_module.conv.out_channels,
......@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module):
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.output_dtype = float_module.act_observer.get_dtype()
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()
return qconv
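Editor's note: ``fold_weight_bias`` merges the BatchNorm statistics into the convolution before the weights are quantized. With running mean mu, variance var, affine parameters gamma/beta, and eps, the standard folding is w_fold = w * gamma / sqrt(var + eps) and b_fold = beta + (b - mu) * gamma / sqrt(var + eps). A numpy sketch of that arithmetic (the textbook formula; MegEngine's helper may broadcast differently):

import numpy as np

def fold_weight_bias(w, b, gamma, beta, mean, var, eps=1e-5):
    # Per-output-channel scale absorbed from BatchNorm.
    k = gamma / np.sqrt(var + eps)
    w_fold = w * k.reshape(-1, 1, 1, 1)   # w: (OC, IC, KH, KW)
    b_fold = beta + (b - mean) * k        # b, mean, var, gamma, beta: (OC,)
    return w_fold, b_fold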
......
@@ -34,12 +34,10 @@ class Elemwise(Module):
     _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

-    def __init__(self, method):
+    def __init__(self, method, dtype=None):
         super().__init__()
         self.method = self._elemwise_multi_type_mode.convert("Q" + method)
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype

     def forward(self, *inps):
         if self.training:
@@ -53,7 +51,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = Elemwise(float_module.method.name)
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
@@ -16,11 +16,9 @@ class QuantStub(Module):
     A helper quantize operation on input and inference only.
     """

-    def __init__(self):
+    def __init__(self, dtype=None):
         super().__init__()
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype

     def forward(self, inp):
         if self.training:
@@ -45,10 +43,7 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = QuantStub()
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return QuantStub(float_module.act_observer.get_dtype())


 @register_method_to_class(Float.DequantStub)
@@ -57,5 +52,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = DequantStub()
-    return qmod
+    return DequantStub()
@@ -14,5 +14,6 @@ from .quantize import (
     enable_fake_quant,
     enable_observer,
     quantize,
+    quantize_calibration,
     quantize_qat,
 )
@@ -11,7 +11,7 @@ import numpy as np
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, ones, tensor, zeros
+from ..core import Buffer, Function, tensor
 from ..module import Module
......
@@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True):
     else:
         setattr(parent, key.split(".")[-1], submodule.to_quantized())
     return module


 def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
     r"""
@@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
     module.apply(fn)


+def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+    r"""
+    Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
+    and set qconfig accordingly.
+
+    :param module: root module to convert recursively.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+        Default is :any:`~.qconfig.ema_fakequant_qconfig`.
+    """
+
+    def fn(mod: Module):
+        if isinstance(mod, QATModule):
+            mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
+            mod.set_qconfig(qconfig)
+
+    module.apply(fn)
+    enable_observer(module)
+
+
 def disable_fake_quant(module: Module):
     r"""
     Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`
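Editor's note: putting the new entry point together, a typical post-training calibration flow would be: switch the model into CALIBRATION mode, feed representative data so the observers record activation ranges, then convert to the inference-only quantized modules. A hedged usage sketch (``model`` and ``calibration_loader`` are placeholder names; only ``quantize_calibration`` and ``quantize`` come from this diff):

import megengine.quantization as Q

Q.quantize_calibration(model)      # set CALIBRATION mode and enable observers
for data in calibration_loader:    # observers collect activation statistics
    model(data)
Q.quantize(model)                  # freeze qparams into quantized modules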
......