Commit 555ecea9 authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/quantization): add bias fakequant support

GitOrigin-RevId: a5e953b3fa3e0cf91b03708c26dca4561243504a
Parent 9440842e
......@@ -138,6 +138,7 @@ class Tensor:
def __init__(self, val=None, *, requires_grad=None):
self._reset(val, requires_grad=requires_grad)
self.q_dict = {"mode": None, "scale": None, "zero_point": None}
def _reset(self, val=None, *, requires_grad=None):
self.__sym_override = None
......@@ -677,9 +678,9 @@ class Tensor:
def __deepcopy__(self, memo):
"""
Since Tensor has __getstate__ and __setstate__ methods,
deepcopy only processes those and ignores the attributes of Parameter.
So we need to add a __deepcopy__ method to deepcopy the correct attribute.
The default deepcopy will ignore other attributes except those defined in the
__getstate__ and __setstate__ methods.
So we need to add a __deepcopy__ method to deepcopy the correct attributes.
"""
assert (self.__val is not None) and (
self.__sym is None
......
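With this change every Tensor starts out with an empty qparams slot that the QAT machinery fills in later. A minimal illustration (assuming an ordinary constructed megengine Tensor):

t = Tensor([1.0])
print(t.q_dict)   # {"mode": None, "scale": None, "zero_point": None}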
......@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float
from .module import QATModule
......@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
conv = self.calc_conv(inp, w_qat, self.bias)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
conv = self.calc_conv(inp, w_qat, b_qat)
return conv
@classmethod
......
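The change above is the heart of the feature: the bias fed into the convolution is now fake-quantized too. Its scale must be s_inp * s_w because integer inference accumulates int8 products, each carrying that combined scale, into an int32 accumulator that the bias is added to. A minimal NumPy sketch of that identity (illustrative values, not MegEngine API):

import numpy as np

s_inp, s_w = 0.05, 0.02                    # activation and weight scales
q_inp = np.array([10, -3, 7], dtype=np.int32)
q_w = np.array([4, 9, -2], dtype=np.int32)
bias = 0.123

acc = (q_inp * q_w).sum()                  # int32 accumulator, scale s_inp * s_w
q_bias = np.round(bias / (s_inp * s_w))    # bias quantized to the accumulator scale
out = (acc + q_bias) * (s_inp * s_w)       # single dequantization at the end

ref = (s_inp * q_inp * s_w * q_w).sum() + bias
assert abs(out - ref) <= s_inp * s_w       # differs only by the bias rounding error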
......@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule
......@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
b_qat = fake_quant_bias(b_fold, inp, w_qat)
conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx):
return conv
......
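b_fold above is the standard BN-folding identity: for y = w * x + b followed by bn(y) = gamma * (y - mean) * istd + beta, the fused form is w_fold * x + b_fold with w_fold = gamma * istd * w. A scalar sanity check (the real module does this per channel; values are illustrative):

import numpy as np

w, b, x = 0.7, 0.1, 2.0
gamma, beta, mean, istd = 1.5, -0.2, 0.4, 0.8

bn_out = gamma * ((w * x + b) - mean) * istd + beta

w_fold = gamma * istd * w
b_fold = beta + gamma * (b - mean) * istd
assert np.isclose(bn_out, w_fold * x + b_fold)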
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import linear as Float
from .module import QATModule
......@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
def forward(self, x):
w_qat = self.apply_quant_weight(self.weight)
return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
b_qat = fake_quant_bias(self.bias, x, w_qat)
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
@classmethod
def from_float_module(cls, float_module: Float.Linear):
......
......@@ -73,11 +73,16 @@ class QATModule(Module):
if observer is None:
return target
oup = observer(target)
if fake_quant is None:
return oup
else:
q_dict = observer.get_qparams()
return fake_quant(oup, q_dict)
q_dict = observer.get_qparams()
# do fake quant
if fake_quant is not None:
oup = fake_quant(oup, q_dict)
# use qparams of fake_quant if it provides them.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
# attach the qparams to the output tensor.
oup.q_dict.update(q_dict)
return oup
def apply_quant_weight(self, target: Tensor):
r"""
......
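The refactor above separates the steps: observe, then fake-quantize, then record the final qparams on the output tensor; that recorded q_dict is exactly what fake_quant_bias later reads from inp and w_qat. A standalone mock of the observer half (a stand-in class, not the MegEngine Observer):

class MockObserver:
    """Stand-in observer: tracks the running max-abs and derives a symmetric scale."""
    def __init__(self):
        self.max_val = 0.0

    def __call__(self, x):
        self.max_val = max(self.max_val, max(abs(v) for v in x))
        return x

    def get_qparams(self):
        return {"mode": "symmetric", "scale": self.max_val / 127, "zero_point": None}

obs = MockObserver()
out = obs([0.5, -1.27, 0.3])     # observe and pass through unchanged
q_dict = obs.get_qparams()       # {"scale": 0.01, ...}
# a fake_quant step would run here; finally the qparams are attached via
# out.q_dict.update(q_dict), which is what fake_quant_bias later reads.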
......@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import Round
from .utils import QuantMode, get_qparam_dict
from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict
class _FakeQuantize(Module):
......@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
"""
def fake_quant_forward(self, inp, q_dict):
if q_dict["mode"] == QuantMode.SYMMERTIC:
scale = q_dict["scale"]
# Quant
oup = Round()(inp / scale)
# clip
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
# DeQuant
oup = (oup) * scale
return oup
else:
scale = q_dict["scale"]
zero_point = q_dict["zero_point"]
# Quant
oup = Round()(inp / scale) + zero_point
# clip
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
# DeQuant
oup = (oup - zero_point) * scale
return oup
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
......@@ -13,18 +13,10 @@ import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor
from ..core import Buffer
from ..jit import sideeffect
from ..module import Module
from .utils import QuantMode, get_qparam_dict
class Round(Function):
def forward(self, x):
return x.round()
def backward(self, output_grads):
return output_grads
from .utils import QuantMode, Round, get_qparam_dict
class Observer(Module):
......
......@@ -8,6 +8,24 @@
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Dict
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..core import Function, Tensor
class Round(Function):
"""
The functional round has no gradient and cannot be used for quantization-aware training.
We use Function and STE (Straight-Through Estimator) to implement the backward propagation.
"""
def forward(self, x):
return x.round()
def backward(self, output_grads):
return output_grads
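# For exposition only (not part of this commit): the same straight-through
# estimator can be written with a detach/stop-gradient trick in any autograd
# framework, e.g. in PyTorch -- forward equals round(x), backward sees identity:
import torch

x = torch.tensor([0.3, 1.7, -2.4], requires_grad=True)
out = x + (torch.round(x) - x).detach()
out.sum().backward()
print(out.detach())   # tensor([ 0.,  2., -2.])
print(x.grad)         # tensor([1., 1., 1.])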
def register_method_to_class(cls):
......@@ -25,6 +43,9 @@ def register_method_to_class(cls):
class QuantMode(Enum):
"""Quantization mode enumerate class.
"""
SYMMERTIC = 1
ASYMMERTIC = 2
TQT = 3
......@@ -41,5 +62,55 @@ qparam_dict = {
}
def get_qparam_dict(mode):
def get_qparam_dict(mode: QuantMode):
"""Return the quantization parameters dictory according to the mode.
"""
return qparam_dict.get(mode, None)
def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
"""Apply fake quantization to the inp tensor.
:param inp: the input tensor which needs to be fake-quantized.
:param qmin: the lower bound of the quantized integer range.
:param qmax: the upper bound of the quantized integer range.
:param q_dict: the quantization parameter dict.
"""
scale = q_dict["scale"]
zero_point = 0
if q_dict["mode"] == QuantMode.ASYMMERTIC:
zero_point = q_dict["zero_point"]
# Quant
oup = Round()(inp / scale) + zero_point
# Clip
oup = F.minimum(F.maximum(oup, qmin), qmax)
# Dequant
oup = (oup - zero_point) * scale
return oup
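# Concretely, the quant/clip/dequant round trip above with illustrative
# symmetric int8-like parameters:
import numpy as np

scale, qmin, qmax = 0.1, -128, 127
inp = np.array([0.34, 15.0, -0.07])
q = np.clip(np.round(inp / scale), qmin, qmax)   # [  3. 127.  -1.]
print(q * scale)                                 # [ 0.3 12.7 -0.1]
# 15.0 saturates to 12.7: values outside [qmin * scale, qmax * scale] are
# clipped; in-range values move by at most half a quantization step.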
def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
"""Apply fake quantization to bias, the special scale from input tensor
and weight tensor, the quantized type set to qint32 also.
:param bias: the bias tensor which need to be faked.
:param inp: the input tensor which contain the quantization parameters.
:param qmax: the weight tensor which contain the quantization parameters.
.. warning::
Only work for symmetric quantization method now.
"""
b_qat = bias
if hasattr(inp, "q_dict") and b_qat is not None:
if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
# use the same mode as the weight.
b_dict = get_qparam_dict(w_qat.q_dict["mode"])
b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
# TODO: add zero_point for ASYMMERTIC mode.
qmax = _metadata_dict["qint32"].qmax
qmin = _metadata_dict["qint32"].qmin
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
return b_qat
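Putting the pieces together, when both scales are present the helper reduces to the following scalar arithmetic (a sketch; the qint32 limits come from _metadata_dict in the real code and are assumed here to be the full int32 range):

import numpy as np

s_inp, s_w = 0.05, 0.02
bias = 3.1416
qmin, qmax = -(2 ** 31), 2 ** 31 - 1      # assumed qint32 range

s_bias = s_inp * s_w                      # 0.001, shared with the conv accumulator
b_qat = np.clip(np.round(bias / s_bias), qmin, qmax) * s_bias
print(b_qat)                              # 3.142, quantization error ~ 4e-4
# If either scale is still None (observer not run yet), the helper returns
# the bias unchanged, so float pretraining is unaffected.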