From 555ecea9bcc4b1157a54228a4e57d1291ae69562 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 12 Aug 2020 10:49:51 +0800
Subject: [PATCH] feat(mge/quantization): add bias fakequant support

GitOrigin-RevId: a5e953b3fa3e0cf91b03708c26dca4561243504a
---
 python_module/megengine/core/tensor.py        |  7 +-
 python_module/megengine/module/qat/conv.py    |  4 +-
 python_module/megengine/module/qat/conv_bn.py |  4 +-
 python_module/megengine/module/qat/linear.py  |  4 +-
 python_module/megengine/module/qat/module.py  | 15 ++--
 .../megengine/quantization/fake_quant.py      | 23 +-----
 .../megengine/quantization/observer.py        | 12 +--
 python_module/megengine/quantization/utils.py | 73 ++++++++++++++++++-
 8 files changed, 99 insertions(+), 43 deletions(-)

diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py
index b823544bf..f96f90f2f 100644
--- a/python_module/megengine/core/tensor.py
+++ b/python_module/megengine/core/tensor.py
@@ -138,6 +138,7 @@ class Tensor:
 
     def __init__(self, val=None, *, requires_grad=None):
         self._reset(val, requires_grad=requires_grad)
+        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
 
     def _reset(self, val=None, *, requires_grad=None):
         self.__sym_override = None
@@ -677,9 +678,9 @@ class Tensor:
 
     def __deepcopy__(self, memo):
         """
-        Since Tensor have __getstate__ and __setstate__ method,
-        deepcopy only process the that and ignore the attribute of Parameter.
-        So we need to add __deepcopy__ method to deepcopy correct attribute.
+        The default deepcopy ignores all attributes except those handled by
+        the __getstate__ and __setstate__ methods.
+        So we add a __deepcopy__ method to deepcopy the correct attributes.
         """
         assert (self.__val is not None) and (
             self.__sym is None
diff --git a/python_module/megengine/module/qat/conv.py b/python_module/megengine/module/qat/conv.py
index 489f94cb4..315da839e 100644
--- a/python_module/megengine/module/qat/conv.py
+++ b/python_module/megengine/module/qat/conv.py
@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
+from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule
 
@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
 
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        conv = self.calc_conv(inp, w_qat, self.bias)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
 
     @classmethod
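For intuition about the calc_conv_qat change above: fake_quant_bias (added to
quantization/utils.py later in this patch) fake-quantizes the bias with a scale
tied to the activations, scale_bias = scale_input * scale_weight, clipped to the
qint32 range. A minimal NumPy sketch of that arithmetic with made-up scale
values (illustration only, not MegEngine API):

    import numpy as np

    def fake_quant(x, scale, qmin, qmax):
        # quant -> clip -> dequant, the same pipeline fake_quant_tensor uses
        q = np.clip(np.round(x / scale), qmin, qmax)
        return q * scale

    s_inp, s_w = 0.02, 0.05                # observed input / weight scales
    s_bias = s_inp * s_w                   # derived bias scale: 0.001
    bias = np.array([0.37, -1.24])
    qmin, qmax = -(2 ** 31), 2 ** 31 - 1   # qint32 limits
    print(fake_quant(bias, s_bias, qmin, qmax))  # [ 0.37 -1.24], snapped to the grid
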
diff --git a/python_module/megengine/module/qat/conv_bn.py b/python_module/megengine/module/qat/conv_bn.py
index 38e120521..9ed6ebab5 100644
--- a/python_module/megengine/module/qat/conv_bn.py
+++ b/python_module/megengine/module/qat/conv_bn.py
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
+from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule
 
@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
 
         w_qat = self.apply_quant_weight(w_fold)
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        conv = self.conv.calc_conv(inp, w_qat, b_qat)
 
         if not (self.training and approx):
             return conv
diff --git a/python_module/megengine/module/qat/linear.py b/python_module/megengine/module/qat/linear.py
index d8174624f..4067d51c6 100644
--- a/python_module/megengine/module/qat/linear.py
+++ b/python_module/megengine/module/qat/linear.py
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule
 
@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
 
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
+        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
 
     @classmethod
     def from_float_module(cls, float_module: Float.Linear):
diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py
index 7eec68658..5c510cca5 100644
--- a/python_module/megengine/module/qat/module.py
+++ b/python_module/megengine/module/qat/module.py
@@ -73,11 +73,16 @@ class QATModule(Module):
         if observer is None:
             return target
         oup = observer(target)
-        if fake_quant is None:
-            return oup
-        else:
-            q_dict = observer.get_qparams()
-            return fake_quant(oup, q_dict)
+        q_dict = observer.get_qparams()
+        # do fake quant
+        if fake_quant is not None:
+            oup = fake_quant(oup, q_dict)
+            # use the qparams of fake_quant if available.
+            if hasattr(fake_quant, "get_qparams"):
+                q_dict = fake_quant.get_qparams()
+        # record the qparams on the output tensor.
+        oup.q_dict.update(q_dict)
+        return oup
 
     def apply_quant_weight(self, target: Tensor):
         r"""
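The module.py change above is what makes the bias path possible: after
apply_quant_weight / apply_quant_activation, every QAT tensor now carries its
quantization parameters in q_dict, so fake_quant_bias can read them without any
extra plumbing. A self-contained sketch of that contract, using a toy stand-in
for Tensor and invented scale values (hypothetical names, illustration only):

    from enum import Enum

    class QuantMode(Enum):  # mirrors megengine.quantization.utils.QuantMode
        SYMMERTIC = 1
        ASYMMERTIC = 2

    class ToyTensor:
        # stand-in for megengine Tensor with the new q_dict attribute
        def __init__(self, value):
            self.value = value
            self.q_dict = {"mode": None, "scale": None, "zero_point": None}

    inp, w_qat = ToyTensor(1.0), ToyTensor(0.5)
    inp.q_dict.update(mode=QuantMode.SYMMERTIC, scale=0.02)
    w_qat.q_dict.update(mode=QuantMode.SYMMERTIC, scale=0.05)

    # fake_quant_bias derives the bias scale from the two q_dicts:
    print(inp.q_dict["scale"] * w_qat.q_dict["scale"])  # 0.001
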
diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py
index b2e9d9393..da365f169 100644
--- a/python_module/megengine/quantization/fake_quant.py
+++ b/python_module/megengine/quantization/fake_quant.py
@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, Parameter
 from ..jit import sideeffect
 from ..module import Module
-from .observer import Round
-from .utils import QuantMode, get_qparam_dict
+from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict
 
 
 class _FakeQuantize(Module):
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
     """
 
     def fake_quant_forward(self, inp, q_dict):
-        if q_dict["mode"] == QuantMode.SYMMERTIC:
-            scale = q_dict["scale"]
-            # Quant
-            oup = Round()(inp / scale)
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup) * scale
-            return oup
-        else:
-            scale = q_dict["scale"]
-            zero_point = q_dict["zero_point"]
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
+        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py
index 33715574e..6aa3a4060 100644
--- a/python_module/megengine/quantization/observer.py
+++ b/python_module/megengine/quantization/observer.py
@@ -13,18 +13,10 @@ import numpy as np
 
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, tensor
+from ..core import Buffer
 from ..jit import sideeffect
 from ..module import Module
-from .utils import QuantMode, get_qparam_dict
-
-
-class Round(Function):
-    def forward(self, x):
-        return x.round()
-
-    def backward(self, output_grads):
-        return output_grads
+from .utils import QuantMode, Round, get_qparam_dict
 
 
 class Observer(Module):
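Round moves out of observer.py into utils.py (below) so that observers and
fake-quant modules can share one definition. Its backward is the straight-through
estimator: round() is piecewise constant with a zero gradient almost everywhere,
so STE substitutes the identity and lets gradients keep flowing during QAT. A
framework-free sketch of the idea (NumPy, illustration only):

    import numpy as np

    def round_ste_forward(x):
        return np.round(x)

    def round_ste_backward(output_grads):
        # straight-through: pretend d round(x) / dx == 1
        return output_grads

    x = np.array([0.4, 1.6, -2.3])
    print(round_ste_forward(x))                 # [ 0.  2. -2.]
    print(round_ste_backward(np.ones_like(x)))  # gradients pass through unchanged
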
diff --git a/python_module/megengine/quantization/utils.py b/python_module/megengine/quantization/utils.py
index 470e39ae6..2b940c9df 100644
--- a/python_module/megengine/quantization/utils.py
+++ b/python_module/megengine/quantization/utils.py
@@ -8,6 +8,24 @@
 
 from enum import Enum
 from functools import partial, update_wrapper, wraps
+from typing import Dict
+
+from .. import functional as F
+from .._internal.dtype import _metadata_dict
+from ..core import Function, Tensor
+
+
+class Round(Function):
+    """
+    The functional round has no gradient and cannot be used for quantization-aware training.
+    We use Function with STE (Straight-Through Estimator) to implement the backward propagation.
+    """
+
+    def forward(self, x):
+        return x.round()
+
+    def backward(self, output_grads):
+        return output_grads
 
 
 def register_method_to_class(cls):
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
 
 
 class QuantMode(Enum):
+    """Quantization mode enumeration class.
+    """
+
     SYMMERTIC = 1
     ASYMMERTIC = 2
     TQT = 3
@@ -41,5 +62,55 @@ qparam_dict = {
 }
 
 
-def get_qparam_dict(mode):
+def get_qparam_dict(mode: QuantMode):
+    """Return the quantization parameter dictionary according to the mode.
+    """
     return qparam_dict.get(mode, None)
+
+
+def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
+    """Apply fake quantization to the inp tensor.
+
+    :param inp: the input tensor to be fake-quantized.
+    :param qmin: the minimum value of the quantized integer range.
+    :param qmax: the maximum value of the quantized integer range.
+    :param q_dict: the quantization parameter dict.
+
+    """
+    scale = q_dict["scale"]
+    zero_point = 0
+    if q_dict["mode"] == QuantMode.ASYMMERTIC:
+        zero_point = q_dict["zero_point"]
+    # Quant
+    oup = Round()(inp / scale) + zero_point
+    # Clip
+    oup = F.minimum(F.maximum(oup, qmin), qmax)
+    # Dequant
+    oup = (oup - zero_point) * scale
+    return oup
+
+
+def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
+    """Apply fake quantization to the bias, with a scale derived from the input
+    tensor and the weight tensor; the quantized dtype is set to qint32.
+
+    :param bias: the bias tensor to be fake-quantized.
+    :param inp: the input tensor whose q_dict provides the quantization parameters.
+    :param w_qat: the weight tensor whose q_dict provides the quantization parameters.
+
+    .. warning::
+        Only works for the symmetric quantization method now.
+
+    """
+    b_qat = bias
+    if hasattr(inp, "q_dict") and b_qat is not None:
+        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
+            # use the same mode as the weight.
+            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
+            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
+            # TODO: add zero_point for ASYMMERTIC mode.
+            qmax = _metadata_dict["qint32"].qmax
+            qmin = _metadata_dict["qint32"].qmin
+            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+
+    return b_qat
-- 
GitLab
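As a worked example of the consolidated fake_quant_tensor above: symmetric mode
keeps zero_point at 0, while asymmetric mode shifts the integer grid before
clipping, and both dequantize back to the same values. A NumPy transcription of
the quant -> clip -> dequant pipeline with toy values (sketch only, not the
MegEngine call):

    import numpy as np

    def fake_quant_tensor_np(x, qmin, qmax, scale, zero_point=0):
        q = np.round(x / scale) + zero_point  # Quant
        q = np.clip(q, qmin, qmax)            # Clip
        return (q - zero_point) * scale       # Dequant

    x = np.array([-1.0, -0.1, 0.0, 0.7, 3.0])
    # symmetric qint8: zero_point stays 0
    print(fake_quant_tensor_np(x, -128, 127, scale=0.02))
    # asymmetric quint8: zero_point recenters the grid
    print(fake_quant_tensor_np(x, 0, 255, scale=0.02, zero_point=128))
    # both clip 3.0 to the top of the range: 127 * 0.02 == 2.54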