Commit 7c4f1a38 authored by Megvii Engine Team

feat(mge/quantization): add calibration support

GitOrigin-RevId: f16fbba2b7cbc6138c4382fcb96b70f5eb71074c
Parent 5eca4da3
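In short: `QATModule` gains a working `CALIBRATION` mode in which observers keep collecting statistics but no fake quantization is applied, and the inference-only quantized modules now take their output dtype directly from the activation observer instead of carrying separate `scale`/`zero_point` attributes. Below is a hedged sketch of the resulting post-training calibration flow; `net` and `calib_batches` are hypothetical stand-ins, and the import path assumes the re-exports added to the package `__init__` later in this diff:

```python
# Hedged sketch of the calibration flow this commit enables; `net` is a
# hypothetical QAT-ready model, `calib_batches` a representative data loader.
from megengine.quantization import quantize, quantize_calibration

quantize_calibration(net)      # CALIBRATION mode + qconfig + observers enabled
for batch in calib_batches:    # plain forward passes: observers record ranges,
    net(batch)                 # no fake quantization is applied
quantize(net)                  # convert to quantized modules; each takes its
                               # output dtype from act_observer.get_dtype()
```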
@@ -496,8 +496,11 @@ class QATModule(Module):
         self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
     ):
         oup = self.apply_observer(target, obs)
-        scale, zero_point = obs.get_qparams()
-        return fq(oup, scale, zero_point)
+        if self.quantizing == self.QATMode.CALIBRATION:
+            return oup
+        else:
+            scale, zero_point = obs.get_qparams()
+            return fq(oup, scale, zero_point)
 
     def set_qat_mode(self, mode: QATMode):
         r"""
@@ -524,11 +527,7 @@ class QATModule(Module):
         """
 
     def __call__(self, *args, **kwargs):
-        if self.quantizing == self.QATMode.QAT:
-            return self.forward_qat(*args, **kwargs)
-        elif self.quantizing == self.QATMode.CALIBRATION:
-            # TODO implement the CALIBRATION
-            assert False
-            return None
-        else:
+        if self.quantizing == self.QATMode.DISABLED:
             return self.forward(*args, **kwargs)
+        else:
+            return self.forward_qat(*args, **kwargs)
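The two hunks above move the mode branching out of `__call__` and into `apply_fakequant_with_observer`. A self-contained toy (not MegEngine code) of the resulting control flow:

```python
# Toy reproduction of the new dispatch: DISABLED runs the float forward,
# QAT and CALIBRATION both run forward_qat, and only QAT fake-quantizes.
# `obs` and `fq` are plain callables here, not the MegEngine classes.
from enum import Enum

class QATMode(Enum):
    DISABLED = "disabled"
    QAT = "qat"
    CALIBRATION = "calibration"

def apply_fakequant_with_observer(target, fq, obs, mode):
    oup = obs(target)                      # statistics recorded in every mode
    if mode == QATMode.CALIBRATION:
        return oup                         # observe only; defer quantization
    scale, zero_point = obs.get_qparams()  # QAT: simulate quantization error
    return fq(oup, scale, zero_point)
```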
@@ -20,11 +20,9 @@ class Concat(Module):
     A :class:`~.Module` to do quantized concat, inference only.
     """
 
-    def __init__(self):
+    def __init__(self, dtype=None):
         super().__init__()
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype
 
     def forward(self, inps: Iterable[Tensor], axis: int = 0):
         if self.training:
@@ -39,7 +37,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = Concat()
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return Concat(float_module.act_observer.get_dtype())
@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d):
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        dtype=None,
     ):
         super().__init__(
             in_channels,
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d):
             conv_mode,
             compute_mode,
         )
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
-        self.weight = self.weight.astype(self.output_dtype)
-        self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
+        self.output_dtype = dtype
 
     def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
         inp_scale = mgb.dtype.get_scale(inp.dtype)
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d):
 
 
 def to_quantized(quantized_class, float_module):
+    output_dtype = float_module.act_observer.get_dtype()
     qconv = quantized_class(
         float_module.conv.in_channels,
         float_module.conv.out_channels,
@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module):
         float_module.conv.padding,
         float_module.conv.dilation,
         float_module.conv.groups,
+        dtype=output_dtype,
     )
     w_fold, b_fold = float_module.fold_weight_bias(
         float_module.bn.running_mean, float_module.bn.running_var
     )
     weight = w_fold.astype(float_module.weight_observer.get_dtype())
-    qconv.output_dtype = float_module.act_observer.get_dtype()
     qconv.weight = Parameter(weight.numpy())
     qconv.bias = Parameter(b_fold.numpy())
-    qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()
 
     return qconv
......
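A pattern repeats across the Concat, ConvBn, Elemwise, and QuantStub hunks: the hard-coded `scale`/`zero_point`/`qint8` triple is replaced by a single dtype passed in from the activation observer, because a quantized dtype already carries its quantization parameters. A self-contained toy (not the real `mgb.dtype` API) illustrating why the separate attributes become redundant:

```python
# Toy model of a quantized dtype: the scale (plus an implicit zero point for
# symmetric qint8) travels with the dtype object, so a module only stores
# the dtype and consumers recover qparams on demand, much as
# mgb.dtype.get_scale is used in calc_conv_quantized above.
from dataclasses import dataclass

@dataclass(frozen=True)
class QInt8:
    scale: float
    zero_point: int = 0  # symmetric quantization

def get_scale(dtype: QInt8) -> float:
    return dtype.scale

output_dtype = QInt8(scale=0.047)        # what an observer's get_dtype() yields
assert get_scale(output_dtype) == 0.047  # no separate self.scale needed
```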
@@ -34,12 +34,10 @@ class Elemwise(Module):
     _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode
 
-    def __init__(self, method):
+    def __init__(self, method, dtype=None):
         super().__init__()
         self.method = self._elemwise_multi_type_mode.convert("Q" + method)
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype
 
     def forward(self, *inps):
         if self.training:
@@ -53,7 +51,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = Elemwise(float_module.method.name)
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
@@ -16,11 +16,9 @@ class QuantStub(Module):
     A helper quantize operation on input and inference only.
     """
 
-    def __init__(self):
+    def __init__(self, dtype=None):
         super().__init__()
-        self.scale = 1.0
-        self.zero_point = 0.0
-        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.output_dtype = dtype
 
     def forward(self, inp):
         if self.training:
@@ -45,10 +43,7 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = QuantStub()
-    qmod.output_dtype = float_module.act_observer.get_dtype()
-    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
-    return qmod
+    return QuantStub(float_module.act_observer.get_dtype())
 
 
 @register_method_to_class(Float.DequantStub)
@@ -57,5 +52,4 @@ def to_quantized(float_module):
     Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
     implemented here to avoid circular import.
     """
-    qmod = DequantStub()
-    return qmod
+    return DequantStub()
@@ -14,5 +14,6 @@ from .quantize import (
     enable_fake_quant,
     enable_observer,
     quantize,
+    quantize_calibration,
     quantize_qat,
 )
@@ -11,7 +11,7 @@ import numpy as np
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, ones, tensor, zeros
+from ..core import Buffer, Function, tensor
 from ..module import Module
......
@@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True):
     else:
         setattr(parent, key.split(".")[-1], submodule.to_quantized())
 
+    return module
+
 
 def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
     r"""
@@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
     module.apply(fn)
 
 
+def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+    r"""
+    Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
+    and set qconfig accordingly.
+
+    :param module: root module to convert recursively.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+        default is :any:`~.qconfig.ema_fakequant_qconfig`.
+    """
+
+    def fn(mod: Module):
+        if isinstance(mod, QATModule):
+            mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
+            mod.set_qconfig(qconfig)
+
+    module.apply(fn)
+    enable_observer(module)
+
+
 def disable_fake_quant(module: Module):
     r"""
     Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`
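Usage-wise, `quantize_calibration` mirrors `quantize_qat` but additionally calls `enable_observer` and switches every `QATModule` to `CALIBRATION` mode, so calibration needs only forward passes. A hedged example; `net`, `calib_batches`, and `my_qconfig` are hypothetical stand-ins:

```python
# Hedged usage sketch: a custom QConfig can be supplied just as with
# quantize_qat; no optimizer or backward pass is involved in calibration.
quantize_calibration(net, qconfig=my_qconfig)  # or rely on ema_fakequant_qconfig
for batch in calib_batches:                    # representative calibration data
    net(batch)                                 # observers accumulate statistics
```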
......