提交 ce88e6c4 编写于 作者: M Megvii Engine Team

feat(mge/quantization): use extra act_fakequant to decide whether to do bias fakequant

GitOrigin-RevId: bf54012155292a5270db34d49cc5e761776b9ad3
上级 1d7dd001
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float from .. import batch_matmul_activation as Float
from .module import QATModule from .module import QATModule
...@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule): ...@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
def forward(self, inp): def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat) b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
@classmethod @classmethod
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float from .. import conv as Float
from .module import QATModule from .module import QATModule
...@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule): ...@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule):
def calc_conv_qat(self, inp): def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled: b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
else:
b_qat = self.bias
conv = self.calc_conv(inp, w_qat, b_qat) conv = self.calc_conv(inp, w_qat, b_qat)
return conv return conv
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...functional import ones, relu, sqrt, sum, zeros from ...functional import ones, relu, sqrt, sum, zeros
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float from .. import conv_bn as Float
from .module import QATModule from .module import QATModule
...@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): ...@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold) w_qat = self.apply_quant_weight(w_fold)
if self.weight_fake_quant and self.weight_fake_quant.enabled: b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
b_qat = fake_quant_bias(b_fold, inp, w_qat)
else:
b_qat = b_fold
conv = self.conv.calc_conv(inp, w_qat, b_qat) conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx): if not (self.training and approx):
return conv return conv
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import linear as Float from .. import linear as Float
from .module import QATModule from .module import QATModule
...@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule): ...@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule):
""" """
def forward(self, x): def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
if self.weight_fake_quant and self.weight_fake_quant.enabled: b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
b_qat = fake_quant_bias(self.bias, x, w_qat) return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
else:
b_qat = self.bias
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
@classmethod @classmethod
def from_float_module(cls, float_module: Float.Linear): def from_float_module(cls, float_module: Float.Linear):
......
...@@ -11,6 +11,7 @@ from abc import abstractmethod ...@@ -11,6 +11,7 @@ from abc import abstractmethod
from ...quantization.fake_quant import FakeQuantize from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig from ...quantization.qconfig import QConfig
from ...quantization.utils import fake_quant_bias
from ...tensor import Tensor from ...tensor import Tensor
from ..module import Module from ..module import Module
...@@ -107,6 +108,24 @@ class QATModule(Module): ...@@ -107,6 +108,24 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer target, self.act_fake_quant, self.act_observer
) )
def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
    r"""
    Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
    ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
    """
    # The bias is stored in the activation dtype, so the activation
    # fake-quant switch also governs whether bias fake-quant runs.
    act_fq = self.act_fake_quant
    weight_fq = self.weight_fake_quant
    # Guard clause: leave the bias untouched unless BOTH fake-quant
    # modules exist and are enabled.
    if not (act_fq and act_fq.enabled):
        return target
    if not (weight_fq and weight_fq.enabled):
        return target
    return fake_quant_bias(target, inp, w_qat)
def _get_method_result( def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer self, method: str, fake_quant: FakeQuantize, observer: Observer
): ):
......
...@@ -30,4 +30,10 @@ from .quantize import ( ...@@ -30,4 +30,10 @@ from .quantize import (
quantize_qat, quantize_qat,
reset_qconfig, reset_qconfig,
) )
from .utils import QParams, QuantMode, create_qparams from .utils import (
QParams,
QuantMode,
create_qparams,
fake_quant_bias,
fake_quant_tensor,
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册