diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 4ee9f62be5ab2fc02da0cbd7b45d262fc7e6f302..60e77dca0a7f53476e6189ed8315fd3163cb9196 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -496,8 +496,11 @@ class QATModule(Module): self, target: Tensor, fq: "FakeQuantize", obs: "Observer" ): oup = self.apply_observer(target, obs) - scale, zero_point = obs.get_qparams() - return fq(oup, scale, zero_point) + if self.quantizing == self.QATMode.CALIBRATION: + return oup + else: + scale, zero_point = obs.get_qparams() + return fq(oup, scale, zero_point) def set_qat_mode(self, mode: QATMode): r""" @@ -524,11 +527,7 @@ class QATModule(Module): """ def __call__(self, *args, **kwargs): - if self.quantizing == self.QATMode.QAT: - return self.forward_qat(*args, **kwargs) - elif self.quantizing == self.QATMode.CALIBRATION: - # TODO implement the CALIBRATION - assert False - return None - else: + if self.quantizing == self.QATMode.DISABLED: return self.forward(*args, **kwargs) + else: + return self.forward_qat(*args, **kwargs) diff --git a/python_module/megengine/module/quantized/concat.py b/python_module/megengine/module/quantized/concat.py index 62a7778a8923efb089dec2dd3799f466fff1008c..f3f266a1d5615962c76f8927324065d0857b2b67 100644 --- a/python_module/megengine/module/quantized/concat.py +++ b/python_module/megengine/module/quantized/concat.py @@ -20,11 +20,9 @@ class Concat(Module): A :class:`~.Module` to do quantized concat, inference only. """ - def __init__(self): + def __init__(self, dtype=None): super().__init__() - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): if self.training: @@ -39,7 +37,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = Concat() - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return Concat(float_module.act_observer.get_dtype()) diff --git a/python_module/megengine/module/quantized/conv_bn_relu.py b/python_module/megengine/module/quantized/conv_bn_relu.py index dfc502a729e14478bebbfa51c1215e2663c35a95..18eddaa8f24add3c47ad16e9a4e8a962e02e8c88 100644 --- a/python_module/megengine/module/quantized/conv_bn_relu.py +++ b/python_module/megengine/module/quantized/conv_bn_relu.py @@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d): groups: int = 1, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + dtype=None, ): super().__init__( in_channels, @@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d): conv_mode, compute_mode, ) - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) - self.weight = self.weight.astype(self.output_dtype) - self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) + self.output_dtype = dtype def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): inp_scale = mgb.dtype.get_scale(inp.dtype) @@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d): def to_quantized(quantized_class, float_module): + output_dtype = float_module.act_observer.get_dtype() qconv = quantized_class( float_module.conv.in_channels, float_module.conv.out_channels, @@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module): float_module.conv.padding, float_module.conv.dilation, float_module.conv.groups, + dtype=output_dtype, ) w_fold, b_fold = float_module.fold_weight_bias( float_module.bn.running_mean, float_module.bn.running_var ) weight = w_fold.astype(float_module.weight_observer.get_dtype()) - qconv.output_dtype = float_module.act_observer.get_dtype() qconv.weight = Parameter(weight.numpy()) qconv.bias = Parameter(b_fold.numpy()) - qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() return qconv diff --git a/python_module/megengine/module/quantized/elemwise.py b/python_module/megengine/module/quantized/elemwise.py index 9a03ac9a1811f5d4ea766b4712991f101909aba7..47f30e47a7365ffea4fc951e53f7190fb8b98aa8 100644 --- a/python_module/megengine/module/quantized/elemwise.py +++ b/python_module/megengine/module/quantized/elemwise.py @@ -34,12 +34,10 @@ class Elemwise(Module): _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode - def __init__(self, method): + def __init__(self, method, dtype=None): super().__init__() self.method = self._elemwise_multi_type_mode.convert("Q" + method) - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, *inps): if self.training: @@ -53,7 +51,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = Elemwise(float_module.method.name) - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) diff --git a/python_module/megengine/module/quantized/quant_dequant.py b/python_module/megengine/module/quantized/quant_dequant.py index 5faf923874130584a1fc51bebe94978530114f5c..5a91b6fd19dfde2acd359b340a4a3aae76c8417f 100644 --- a/python_module/megengine/module/quantized/quant_dequant.py +++ b/python_module/megengine/module/quantized/quant_dequant.py @@ -16,11 +16,9 @@ class QuantStub(Module): A helper quantize operation on input and inference only. """ - def __init__(self): + def __init__(self, dtype=None): super().__init__() - self.scale = 1.0 - self.zero_point = 0.0 - self.output_dtype = mgb.dtype.qint8(self.scale) + self.output_dtype = dtype def forward(self, inp): if self.training: @@ -45,10 +43,7 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = QuantStub() - qmod.output_dtype = float_module.act_observer.get_dtype() - qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() - return qmod + return QuantStub(float_module.act_observer.get_dtype()) @register_method_to_class(Float.DequantStub) @@ -57,5 +52,4 @@ def to_quantized(float_module): Replace :class:`~.module.QATModule`'s ``to_quantized`` method. implemented here to avoid circular import. """ - qmod = DequantStub() - return qmod + return DequantStub() diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 46145bd8f09c5d3debf1a2b1c4f1dacfc3ccb6a1..9d490be8b0ea019ddc07ad056d29fc757efb7d5a 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -14,5 +14,6 @@ from .quantize import ( enable_fake_quant, enable_observer, quantize, + quantize_calibration, quantize_qat, ) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 3c4484e6190eec3b17ddaec31d839de128143513..b6799e790744cb8a8568cf3a3ec435fba016d42f 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -11,7 +11,7 @@ import numpy as np from .. import functional as F from .._internal.dtype import _metadata_dict, get_quantized_dtype -from ..core import Buffer, Function, ones, tensor, zeros +from ..core import Buffer, Function, tensor from ..module import Module diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 1ce5c953f9cdba494c5220d49c8787fe0719e942..1bfba352a322f41808d67db6d95807db51abfba6 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True): else: setattr(parent, key.split(".")[-1], submodule.to_quantized()) + return module + def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): r""" @@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): module.apply(fn) +def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): + r""" + Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` + and set qconfig relatively. + + :param module: root module to do convert recursively. + :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. + default is :any:`~.qconfig.ema_fakequant_qconfig`. + """ + + def fn(mod: Module): + if isinstance(mod, QATModule): + mod.set_qat_mode(QATModule.QATMode.CALIBRATION) + mod.set_qconfig(qconfig) + + module.apply(fn) + enable_observer(module) + + def disable_fake_quant(module: Module): r""" Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`