提交 ce88e6c4 编写于 作者: M Megvii Engine Team

feat(mge/quantization): use extra act_fakequant to decide whether to do bias fakequant

GitOrigin-RevId: bf54012155292a5270db34d49cc5e761776b9ad3
上级 1d7dd001
......@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float
from .module import QATModule
......@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
@classmethod
......
......@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float
from .module import QATModule
......@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule):
def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, inp, w_qat)
else:
b_qat = self.bias
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
conv = self.calc_conv(inp, w_qat, b_qat)
return conv
......
......@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...functional import ones, relu, sqrt, sum, zeros
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule
......@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(b_fold, inp, w_qat)
else:
b_qat = b_fold
b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx):
return conv
......
......@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import linear as Float
from .module import QATModule
......@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule):
"""
def forward(self, x):
def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, x, w_qat)
else:
b_qat = self.bias
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
@classmethod
def from_float_module(cls, float_module: Float.Linear):
......
......@@ -11,6 +11,7 @@ from abc import abstractmethod
from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig
from ...quantization.utils import fake_quant_bias
from ...tensor import Tensor
from ..module import Module
......@@ -107,6 +108,24 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer
)
def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
r"""
Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
"""
# bias should have the same dtype as activation, so act_fake_quant can also
# decide whether to do bias fakequant
if (
self.act_fake_quant
and self.act_fake_quant.enabled
and self.weight_fake_quant
and self.weight_fake_quant.enabled
):
b_qat = fake_quant_bias(target, inp, w_qat)
else:
b_qat = target
return b_qat
def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer
):
......
......@@ -30,4 +30,10 @@ from .quantize import (
quantize_qat,
reset_qconfig,
)
from .utils import QParams, QuantMode, create_qparams
from .utils import (
QParams,
QuantMode,
create_qparams,
fake_quant_bias,
fake_quant_tensor,
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册