From ce88e6c4f727ee7a91a0085108120386d75148d9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 25 Feb 2021 17:12:22 +0800 Subject: [PATCH] feat(mge/quantization): use extra act_fakequant to decide whether to do bias fakequant GitOrigin-RevId: bf54012155292a5270db34d49cc5e761776b9ad3 --- .../module/qat/batch_matmul_activation.py | 4 +--- .../python/megengine/module/qat/conv.py | 6 +----- .../python/megengine/module/qat/conv_bn.py | 6 +----- .../python/megengine/module/qat/linear.py | 10 +++------- .../python/megengine/module/qat/module.py | 19 +++++++++++++++++++ .../python/megengine/quantization/__init__.py | 8 +++++++- 6 files changed, 32 insertions(+), 21 deletions(-) diff --git a/imperative/python/megengine/module/qat/batch_matmul_activation.py b/imperative/python/megengine/module/qat/batch_matmul_activation.py index e3e2a0b2..1b1ff2c7 100644 --- a/imperative/python/megengine/module/qat/batch_matmul_activation.py +++ b/imperative/python/megengine/module/qat/batch_matmul_activation.py @@ -5,8 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - -from ...quantization.utils import fake_quant_bias from .. import batch_matmul_activation as Float from .module import QATModule @@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule): def forward(self, inp): w_qat = self.apply_quant_weight(self.weight) - b_qat = fake_quant_bias(self.bias, inp, w_qat) + b_qat = self.apply_quant_bias(self.bias, inp, w_qat) return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) @classmethod diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index f8205c95..c3608d59 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -6,7 +6,6 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ... import functional as F -from ...quantization.utils import fake_quant_bias from .. import conv as Float from .module import QATModule @@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule): def calc_conv_qat(self, inp): w_qat = self.apply_quant_weight(self.weight) - if self.weight_fake_quant and self.weight_fake_quant.enabled: - b_qat = fake_quant_bias(self.bias, inp, w_qat) - else: - b_qat = self.bias + b_qat = self.apply_quant_bias(self.bias, inp, w_qat) conv = self.calc_conv(inp, w_qat, b_qat) return conv diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index 2cc5be08..3ee4d407 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -6,7 +6,6 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ...functional import ones, relu, sqrt, sum, zeros -from ...quantization.utils import fake_quant_bias from .. import conv_bn as Float from .module import QATModule @@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd w_qat = self.apply_quant_weight(w_fold) - if self.weight_fake_quant and self.weight_fake_quant.enabled: - b_qat = fake_quant_bias(b_fold, inp, w_qat) - else: - b_qat = b_fold + b_qat = self.apply_quant_bias(b_fold, inp, w_qat) conv = self.conv.calc_conv(inp, w_qat, b_qat) if not (self.training and approx): return conv diff --git a/imperative/python/megengine/module/qat/linear.py b/imperative/python/megengine/module/qat/linear.py index 8647fc91..fc92ff18 100644 --- a/imperative/python/megengine/module/qat/linear.py +++ b/imperative/python/megengine/module/qat/linear.py @@ -5,7 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ...quantization.utils import fake_quant_bias from .. import linear as Float from .module import QATModule @@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule): """ - def forward(self, x): + def forward(self, inp): w_qat = self.apply_quant_weight(self.weight) - if self.weight_fake_quant and self.weight_fake_quant.enabled: - b_qat = fake_quant_bias(self.bias, x, w_qat) - else: - b_qat = self.bias - return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat)) + b_qat = self.apply_quant_bias(self.bias, inp, w_qat) + return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) @classmethod def from_float_module(cls, float_module: Float.Linear): diff --git a/imperative/python/megengine/module/qat/module.py b/imperative/python/megengine/module/qat/module.py index 466ba960..ad66d844 100644 --- a/imperative/python/megengine/module/qat/module.py +++ b/imperative/python/megengine/module/qat/module.py @@ -11,6 +11,7 @@ from abc import abstractmethod from ...quantization.fake_quant import FakeQuantize from ...quantization.observer import Observer from ...quantization.qconfig import QConfig +from ...quantization.utils import fake_quant_bias from ...tensor import Tensor from ..module import Module @@ -107,6 +108,24 @@ class QATModule(Module): target, self.act_fake_quant, self.act_observer ) + def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor): + r""" + Use :func:`~.fake_quant_bias` to process ``target``. Only valid when + ``act_fake_quant`` and ``weight_fake_quant`` are both enabled. + """ + # bias should have the same dtype as activation, so act_fake_quant can also + # decide whether to do bias fakequant + if ( + self.act_fake_quant + and self.act_fake_quant.enabled + and self.weight_fake_quant + and self.weight_fake_quant.enabled + ): + b_qat = fake_quant_bias(target, inp, w_qat) + else: + b_qat = target + return b_qat + def _get_method_result( self, method: str, fake_quant: FakeQuantize, observer: Observer ): diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py index 0407cd6a..d33de8d7 100644 --- a/imperative/python/megengine/quantization/__init__.py +++ b/imperative/python/megengine/quantization/__init__.py @@ -30,4 +30,10 @@ from .quantize import ( quantize_qat, reset_qconfig, ) -from .utils import QParams, QuantMode, create_qparams +from .utils import ( + QParams, + QuantMode, + create_qparams, + fake_quant_bias, + fake_quant_tensor, +) -- GitLab