From 7c4f1a3851c1fdbea4aa5ca5f15f19dcc7bf4476 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 9 May 2020 14:55:30 +0800 Subject: [PATCH] feat(mge/quantization): add calibration support GitOrigin-RevId: f16fbba2b7cbc6138c4382fcb96b70f5eb71074c --- python_module/megengine/module/module.py | 17 +++++++-------- .../megengine/module/quantized/concat.py | 11 +++------- .../module/quantized/conv_bn_relu.py | 11 ++++------ .../megengine/module/quantized/elemwise.py | 11 +++------- .../module/quantized/quant_dequant.py | 14 ++++--------- .../megengine/quantization/__init__.py | 1 + .../megengine/quantization/observer.py | 2 +- .../megengine/quantization/quantize.py | 21 +++++++++++++++++++ 8 files changed, 45 insertions(+), 43 deletions(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 4ee9f62be..60e77dca0 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -496,8 +496,11 @@ class QATModule(Module): self, target: Tensor, fq: "FakeQuantize", obs: "Observer" ): oup = self.apply_observer(target, obs) - scale, zero_point = obs.get_qparams() - return fq(oup, scale, zero_point) + if self.quantizing == self.QATMode.CALIBRATION: + return oup + else: + scale, zero_point = obs.get_qparams() + return fq(oup, scale, zero_point) def set_qat_mode(self, mode: QATMode): r""" @@ -524,11 +527,7 @@ class QATModule(Module): """ def __call__(self, *args, **kwargs): - if self.quantizing == self.QATMode.QAT: - return self.forward_qat(*args, **kwargs) - elif self.quantizing == self.QATMode.CALIBRATION: - # TODO implement the CALIBRATION - assert False - return None - else: + if self.quantizing == self.QATMode.DISABLED: return self.forward(*args, **kwargs) + else: + return self.forward_qat(*args, **kwargs) diff --git a/python_module/megengine/module/quantized/concat.py b/python_module/megengine/module/quantized/concat.py index 62a7778a8..f3f266a1d 100644 --- a/python_module/megengine/module/quantized/concat.py +++ b/python_module/megengine/module/quantized/concat.py @@ -20,11 +20,9 @@ class Concat(Module): A :class:`~.Module` to do quantized concat, inference only. """ - def __init__(self): + def __init__(self, dtype=None): super().__init__() - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): if self.training: @@ -39,7 +37,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = Concat() - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return Concat(float_module.act_observer.get_dtype()) diff --git a/python_module/megengine/module/quantized/conv_bn_relu.py b/python_module/megengine/module/quantized/conv_bn_relu.py index dfc502a72..18eddaa8f 100644 --- a/python_module/megengine/module/quantized/conv_bn_relu.py +++ b/python_module/megengine/module/quantized/conv_bn_relu.py @@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d): groups: int = 1, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + dtype=None, ): super().__init__( in_channels, @@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d): conv_mode, compute_mode, ) - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) - self.weight = self.weight.astype(self.output_dtype) - self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) + self.output_dtype = dtype def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): inp_scale = mgb.dtype.get_scale(inp.dtype) @@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d): def to_quantized(quantized_class, float_module): + output_dtype = float_module.act_observer.get_dtype() qconv = quantized_class( float_module.conv.in_channels, float_module.conv.out_channels, @@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module): float_module.conv.padding, float_module.conv.dilation, float_module.conv.groups, + dtype=output_dtype, ) w_fold, b_fold = float_module.fold_weight_bias( float_module.bn.running_mean, float_module.bn.running_var ) weight = w_fold.astype(float_module.weight_observer.get_dtype()) - qconv.output_dtype = float_module.act_observer.get_dtype() qconv.weight = Parameter(weight.numpy()) qconv.bias = Parameter(b_fold.numpy()) - qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() return qconv diff --git a/python_module/megengine/module/quantized/elemwise.py b/python_module/megengine/module/quantized/elemwise.py index 9a03ac9a1..47f30e47a 100644 --- a/python_module/megengine/module/quantized/elemwise.py +++ b/python_module/megengine/module/quantized/elemwise.py @@ -34,12 +34,10 @@ class Elemwise(Module): _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode - def __init__(self, method): + def __init__(self, method, dtype=None): super().__init__() self.method = self._elemwise_multi_type_mode.convert("Q" + method) - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, *inps): if self.training: @@ -53,7 +51,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = Elemwise(float_module.method.name) - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) diff --git a/python_module/megengine/module/quantized/quant_dequant.py b/python_module/megengine/module/quantized/quant_dequant.py index 5faf92387..5a91b6fd1 100644 --- a/python_module/megengine/module/quantized/quant_dequant.py +++ b/python_module/megengine/module/quantized/quant_dequant.py @@ -16,11 +16,9 @@ class QuantStub(Module): A helper quantize operation on input and inference only. """ - def __init__(self): + def __init__(self, dtype=None): super().__init__() - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, inp): if self.training: @@ -45,10 +43,7 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = QuantStub() - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return QuantStub(float_module.act_observer.get_dtype()) @register_method_to_class(Float.DequantStub) @@ -57,5 +52,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = DequantStub() - return qmod + return DequantStub() diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 46145bd8f..9d490be8b 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -14,5 +14,6 @@ from .quantize import ( enable_fake_quant, enable_observer, quantize, + quantize_calibration, quantize_qat, ) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 3c4484e61..b6799e790 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -11,7 +11,7 @@ import numpy as np from .. import functional as F from .._internal.dtype import _metadata_dict, get_quantized_dtype -from ..core import Buffer, Function, ones, tensor, zeros +from ..core import Buffer, Function, tensor from ..module import Module diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 1ce5c953f..1bfba352a 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True): else: setattr(parent, key.split(".")[-1], submodule.to_quantized()) + return module + def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): r""" @@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): module.apply(fn) +def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): + r""" + Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` + and set qconfig relatively. + + :param module: root module to do convert recursively. + :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. + default is :any:`~.qconfig.ema_fakequant_qconfig`. + """ + + def fn(mod: Module): + if isinstance(mod, QATModule): + mod.set_qat_mode(QATModule.QATMode.CALIBRATION) + mod.set_qconfig(qconfig) + + module.apply(fn) + enable_observer(module) + + def disable_fake_quant(module: Module): r""" Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` -- GitLab