Commit 6c4841e8 authored by Megvii Engine Team

fix(mge/quantization): `disable_fake_quant` does not work correctly

GitOrigin-RevId: 0c568d3335aa55cf0fc3fb569b9b124b581e5256
Parent aa953c3b
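The root cause: the QAT Conv2d, _ConvBnActivation2d, and Linear modules called fake_quant_bias unconditionally, so even after disable_fake_quant the bias was still fake-quantized and the QAT forward diverged from the float forward. A minimal sketch of the behaviour this commit restores, following the new test at the bottom of this diff (every name below comes from the diff itself):

import numpy as np

from megengine import module as Float
from megengine import tensor
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_fake_quant, quantize_qat


class Net(Float.Module):
    def __init__(self):
        super().__init__()
        self.quant = Float.QuantStub()
        self.linear = Float.Linear(3, 3)
        self.dequant = Float.DequantStub()

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))


x = tensor(np.random.rand(3, 3).astype(np.float32))
net = Net()
y_float = net(x).numpy()

net = quantize_qat(net, min_max_fakequant_qconfig)
disable_fake_quant(net)
# With fake quant disabled everywhere, including the bias path fixed here,
# the QAT forward should reproduce the float forward:
np.testing.assert_allclose(y_float, net(x).numpy())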
@@ -57,7 +57,7 @@ def _is_module(obj):
 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
-    from .normalization import GroupNorm, LayerNorm, InstanceNorm
+    from .normalization import GroupNorm, InstanceNorm, LayerNorm

     XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
     return XNorm_types
......
@@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        else:
+            b_qat = self.bias
         conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
......
@@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_quant_weight(w_fold)
-        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        else:
+            b_qat = b_fold
         conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv
......
@@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule):
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, x, w_qat)
+        else:
+            b_qat = self.bias
         return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
......
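The three hunks above (Conv2d, _ConvBnActivation2d, Linear) add the same guard around the bias. A sketch of the shared pattern as a hypothetical free function; the wrapper and its name are illustrative only, while fake_quant_bias and weight_fake_quant are the real names from this diff (import path assumed):

from megengine.quantization.utils import fake_quant_bias  # path assumed


def maybe_fake_quant_bias(qat_module, bias, inp, w_qat):
    # Guard added by this commit: only simulate qint32 bias quantization
    # while the weight fake-quant observer exists and is enabled.
    if qat_module.weight_fake_quant and qat_module.weight_fake_quant.enabled:
        return fake_quant_bias(bias, inp, w_qat)
    # disable_fake_quant turns the observer off, so the float bias passes through.
    return bias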
@@ -116,7 +116,7 @@ class TQT(_FakeQuantize):
     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter(0.0, dtype=np.float32)
+        self.scale = Parameter([0.0], dtype=np.float32)

     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
......
@@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver):
         By selecting new min/max, we filter out outliers in input distribution.
         """
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         np_histogram = self.histogram.numpy()
         assert len(np_histogram) == self.bins, "bins mismatch"
         bin_width = (np_max_val - np_min_val) / self.bins
@@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver):
         # This allows us to have a common grid of resolution s, where we can align
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
         downsample_rate = int(
@@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver):
     def sideeffect_forward(self, x_orig):
         x = x_orig.numpy()
-        min_val = self.min_val.numpy()[0]
-        max_val = self.max_val.numpy()[0]
+        min_val = self.min_val.numpy()
+        max_val = self.max_val.numpy()
         histogram = self.histogram.numpy()
         new_min = x.min()
         new_max = x.max()
......
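The scalar-shape edits above, Parameter([0.0], ...) in TQT and the dropped [0] after .numpy() in HistogramObserver, both track 0-dim versus 1-dim tensor semantics. A plausible reading, illustrated with plain numpy (assuming megengine tensors follow numpy's 0-dim indexing rules here):

import numpy as np

scalar = np.array(0.0, dtype=np.float32)    # 0-dim, shape ()
vector = np.array([0.0], dtype=np.float32)  # 1-dim, shape (1,)
assert scalar.shape == () and vector.shape == (1,)

try:
    scalar[0]  # indexing a 0-dim array raises, hence .numpy()[0] -> .numpy()
except IndexError:
    pass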
@@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
         qmax = _metadata_dict["qint32"].qmax
         qmin = _metadata_dict["qint32"].qmin
         b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+        b_qat.q_dict.update(b_dict)

     return b_qat
@@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str):
 def _dump_graphviz(entries: List[ProfileEntry], path: str):
-    import graphviz
     import json

+    import graphviz
+
     graph = graphviz.Digraph()
     graph.graph_attr["ordering"] = "out"
     var_cache = {}
......
@@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply
 def elemwise(*args, mode):
-    from megengine.core.ops.builtin import Elemwise
     from megengine.core._imperative_rt.imperative import apply_op
+    from megengine.core.ops.builtin import Elemwise

     return apply_op(Elemwise(mode), args)
@@ -61,8 +61,8 @@ def test_tensor_on_device():
 def test_raw_tensor():
-    from megengine.core.tensor.raw_tensor import as_raw_tensor
     from megengine.core.ops.builtin import Elemwise
+    from megengine.core.tensor.raw_tensor import as_raw_tensor

     x = np.random.rand(10).astype("float32")
     xx = as_raw_tensor(x)
......
@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
 from megengine import module as Float
+from megengine import tensor
 from megengine.module import qat as QAT
-from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
+from megengine.quantization import min_max_fakequant_qconfig
+from megengine.quantization.quantize import (
+    _get_quantable_module_names,
+    disable_fake_quant,
+    quantize_qat,
+)


 def test_get_quantable_module_names():
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_disable_fake_quant():
+    class Net(Float.Module):
+        def __init__(self):
+            super().__init__()
+            self.quant = Float.QuantStub()
+            self.linear = Float.Linear(3, 3)
+            self.dequant = Float.DequantStub()
+            self.linear.bias.set_value(np.random.rand(3))
+
+        def forward(self, x):
+            x = self.quant(x)
+            x = self.linear(x)
+            x = self.dequant(x)
+            return x
+
+    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
+    net = Net()
+    y1 = net(x).numpy()
+    net = quantize_qat(net, min_max_fakequant_qconfig)
+    y2 = net(x).numpy()
+    disable_fake_quant(net)
+    y3 = net(x).numpy()
+    np.testing.assert_allclose(y1, y3)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y2, y3)