Commit 555ecea9, authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/quantization): add bias fakequant support

GitOrigin-RevId: a5e953b3fa3e0cf91b03708c26dca4561243504a
Parent 9440842e
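In short, this commit teaches the QAT modules (Conv2d, ConvBn, Linear) to fake-quantize the bias as well: the bias shares the accumulator scale of the conv/linear product, scale_bias = scale_input * scale_weight, and is clipped to the qint32 range. A minimal numpy sketch of that rule (illustrative values only, not MegEngine code):

import numpy as np

# Illustrative scales; in MegEngine they come from the input/weight observers.
inp_scale, w_scale = 0.05, 0.002
bias = np.array([0.37, -1.24, 0.08], dtype=np.float32)

# The bias shares the accumulator scale of the conv/linear product.
b_scale = inp_scale * w_scale
qmin, qmax = -(2 ** 31), 2 ** 31 - 1  # qint32 range

# Fake quant: round onto the integer grid, clip, dequantize.
b_q = np.clip(np.round(bias / b_scale), qmin, qmax)
b_qat = (b_q * b_scale).astype(np.float32)  # bias snapped to multiples of b_scale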
@@ -138,6 +138,7 @@ class Tensor:
     def __init__(self, val=None, *, requires_grad=None):
         self._reset(val, requires_grad=requires_grad)
+        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

     def _reset(self, val=None, *, requires_grad=None):
         self.__sym_override = None
@@ -677,9 +678,9 @@ class Tensor:
     def __deepcopy__(self, memo):
         """
-        Since Tensor have __getstate__ and __setstate__ method,
-        deepcopy only process the that and ignore the attribute of Parameter.
-        So we need to add __deepcopy__ method to deepcopy correct attribute.
+        The default deepcopy will ignore other attributes except those defined in
+        the __getstate__ and __setstate__ methods,
+        so we need to add a __deepcopy__ method to deepcopy the correct attributes.
         """
         assert (self.__val is not None) and (
             self.__sym is None
...
@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
+from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule
@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        conv = self.calc_conv(inp, w_qat, self.bias)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        conv = self.calc_conv(inp, w_qat, b_qat)
         return conv

     @classmethod
...
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
+from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule
@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_quant_weight(w_fold)
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv
...
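For context, the b_fold line in the hunk above is the standard BatchNorm-folding identity: with bn_istd = 1/sqrt(var + eps), applying BN to the biased conv output equals a conv whose weight is scaled by gamma * bn_istd and whose bias is beta + gamma * (conv_bias - bn_mean) * bn_istd; it is this folded bias that then gets fake-quantized. A tiny numeric check (made-up scalar values, not MegEngine code):

import numpy as np

gamma, beta = 1.5, 0.2
bn_mean, bn_var, eps = 0.4, 0.25, 1e-5
conv_bias, conv_out = 0.1, 0.9

bn_istd = 1.0 / np.sqrt(bn_var + eps)
# Direct BN on the biased conv output ...
direct = gamma * (conv_out + conv_bias - bn_mean) * bn_istd + beta
# ... equals the conv with folded weight scale and folded bias.
w_scale = gamma * bn_istd  # folds into the weight
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
folded = w_scale * conv_out + b_fold
assert np.isclose(direct, folded)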
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule
@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
+        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):
...
@@ -73,11 +73,16 @@ class QATModule(Module):
         if observer is None:
             return target
         oup = observer(target)
-        if fake_quant is None:
-            return oup
-        else:
-            q_dict = observer.get_qparams()
-            return fake_quant(oup, q_dict)
+        q_dict = observer.get_qparams()
+        # do fake quant
+        if fake_quant is not None:
+            oup = fake_quant(oup, q_dict)
+            # use qparams of fake_quant if available.
+            if hasattr(fake_quant, "get_qparams"):
+                q_dict = fake_quant.get_qparams()
+        # set qparams onto the output tensor.
+        oup.q_dict.update(q_dict)
+        return oup

     def apply_quant_weight(self, target: Tensor):
         r"""
...
@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, Parameter
 from ..jit import sideeffect
 from ..module import Module
-from .observer import Round
-from .utils import QuantMode, get_qparam_dict
+from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict


 class _FakeQuantize(Module):
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
     """

     def fake_quant_forward(self, inp, q_dict):
-        if q_dict["mode"] == QuantMode.SYMMERTIC:
-            scale = q_dict["scale"]
-            # Quant
-            oup = Round()(inp / scale)
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = oup * scale
-            return oup
-        else:
-            scale = q_dict["scale"]
-            zero_point = q_dict["zero_point"]
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
+        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
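The fake_quant_forward rewrite above folds the old symmetric/asymmetric branches into the single fake_quant_tensor helper: with zero_point fixed at 0, the unified round-clip-dequant formula reduces exactly to the removed SYMMERTIC branch. A standalone numpy sketch of that equivalence (hypothetical scales and ranges):

import numpy as np

def fake_quant(x, scale, qmin, qmax, zero_point=0):
    # zero_point = 0 reduces this to the old SYMMERTIC branch.
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = np.array([-0.9, 0.13, 0.77], dtype=np.float32)
sym = fake_quant(x, scale=0.01, qmin=-128, qmax=127)                # symmetric int8
asym = fake_quant(x, scale=0.01, qmin=0, qmax=255, zero_point=128)  # asymmetric uint8
assert np.allclose(sym, asym)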
@@ -13,18 +13,10 @@ import numpy as np
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, tensor
+from ..core import Buffer
 from ..jit import sideeffect
 from ..module import Module
-from .utils import QuantMode, get_qparam_dict
-
-
-class Round(Function):
-    def forward(self, x):
-        return x.round()
-
-    def backward(self, output_grads):
-        return output_grads
+from .utils import QuantMode, Round, get_qparam_dict


 class Observer(Module):
...
@@ -8,6 +8,24 @@
 from enum import Enum
 from functools import partial, update_wrapper, wraps
+from typing import Dict
+
+from .. import functional as F
+from .._internal.dtype import _metadata_dict
+from ..core import Function, Tensor
+
+
+class Round(Function):
+    """
+    The functional round has no grad and cannot be used for
+    quantization-aware training, so we use Function and STE
+    (Straight-Through Estimator) to implement the backward propagation.
+    """
+
+    def forward(self, x):
+        return x.round()
+
+    def backward(self, output_grads):
+        return output_grads


 def register_method_to_class(cls):
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
 class QuantMode(Enum):
+    """Quantization mode enumeration class.
+    """
+
     SYMMERTIC = 1
     ASYMMERTIC = 2
     TQT = 3
@@ -41,5 +62,55 @@ qparam_dict = {
 }


-def get_qparam_dict(mode):
+def get_qparam_dict(mode: QuantMode):
+    """Return the quantization parameter dict according to the mode.
+    """
     return qparam_dict.get(mode, None)
+
+
+def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
+    """Apply fake quantization to the inp tensor.
+
+    :param inp: the input tensor to be fake-quantized.
+    :param qmin: the lower bound of the integer range.
+    :param qmax: the upper bound of the integer range.
+    :param q_dict: the quantization parameter dict.
+    """
+    scale = q_dict["scale"]
+    zero_point = 0
+    if q_dict["mode"] == QuantMode.ASYMMERTIC:
+        zero_point = q_dict["zero_point"]
+    # Quant
+    oup = Round()(inp / scale) + zero_point
+    # Clip
+    oup = F.minimum(F.maximum(oup, qmin), qmax)
+    # Dequant
+    oup = (oup - zero_point) * scale
+    return oup
+
+
+def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
+    """Apply fake quantization to the bias, using a scale derived from the
+    input tensor and the weight tensor; the quantized dtype is set to qint32.
+
+    :param bias: the bias tensor to be fake-quantized.
+    :param inp: the input tensor that carries the quantization parameters.
+    :param w_qat: the weight tensor that carries the quantization parameters.
+
+    .. warning::
+        Only works for the symmetric quantization method now.
+    """
+    b_qat = bias
+    if hasattr(inp, "q_dict") and b_qat is not None:
+        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
+            # use the same mode as the weight.
+            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
+            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
+            # TODO: add zero_point for ASYMMERTIC mode.
+            qmax = _metadata_dict["qint32"].qmax
+            qmin = _metadata_dict["qint32"].qmin
+            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+    return b_qat
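One behavioral detail worth noting in fake_quant_bias: it is a no-op until both the input scale and the weight scale have been observed, so early QAT forward passes run with the float bias. A stand-alone sketch of that fallback (FakeTensor and fake_quant_bias_sketch are illustrative stand-ins, not MegEngine API):

import numpy as np

class FakeTensor:
    # Stand-in mirroring the q_dict attribute this commit adds to Tensor.
    def __init__(self, scale=None, mode="SYMMERTIC"):
        self.q_dict = {"mode": mode, "scale": scale, "zero_point": None}

def fake_quant_bias_sketch(bias, inp, w_qat, qmin=-(2 ** 31), qmax=2 ** 31 - 1):
    # Mirrors fake_quant_bias: pass the bias through untouched until both
    # the input scale and the weight scale are available.
    if bias is None or inp.q_dict["scale"] is None or w_qat.q_dict["scale"] is None:
        return bias
    scale = inp.q_dict["scale"] * w_qat.q_dict["scale"]
    return np.clip(np.round(bias / scale), qmin, qmax) * scale

bias = np.array([0.37], dtype=np.float32)
unobserved = fake_quant_bias_sketch(bias, FakeTensor(), FakeTensor(0.002))
observed = fake_quant_bias_sketch(bias, FakeTensor(0.05), FakeTensor(0.002))
# unobserved is the original bias; observed is snapped to multiples of 1e-4.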