Commit 6c4841e8 authored by Megvii Engine Team

fix(mge/quantization): `disable_fake_quant` does not work correctly

GitOrigin-RevId: 0c568d3335aa55cf0fc3fb569b9b124b581e5256
Parent aa953c3b
@@ -57,7 +57,7 @@ def _is_module(obj):
 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
-    from .normalization import GroupNorm, LayerNorm, InstanceNorm
+    from .normalization import GroupNorm, InstanceNorm, LayerNorm
     XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
     return XNorm_types
......
@@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        else:
+            b_qat = self.bias
         conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
......
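Note on the fix: before this change `calc_conv_qat` ran the bias through `fake_quant_bias` unconditionally, so `disable_fake_quant`, which (as the rest of the commit suggests) works by clearing the `enabled` flag on each fake-quant object, never reached the bias path. A minimal sketch of the guarded pattern; `FakeQuantStub` and `quantize_bias` are hypothetical stand-ins, not MegEngine API:

```python
import numpy as np

class FakeQuantStub:
    """Hypothetical stand-in for a fake-quantize object with an `enabled` flag."""
    def __init__(self, enabled=True):
        self.enabled = enabled

def quantize_bias(bias, weight_fake_quant, fake_quant_fn):
    # Fake-quantize the bias only while weight fake-quant is active;
    # otherwise return the float bias untouched, so that disabling
    # fake quant really restores float behavior.
    if weight_fake_quant and weight_fake_quant.enabled:
        return fake_quant_fn(bias)
    return bias

# Toy check: with fake quant disabled, the bias passes through unchanged.
bias = np.array([0.1, -0.2, 0.3], dtype=np.float32)
stub = FakeQuantStub(enabled=False)
assert quantize_bias(bias, stub, lambda b: np.round(b * 8) / 8) is bias
```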
@@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         w_qat = self.apply_quant_weight(w_fold)
-        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        else:
+            b_qat = b_fold
         conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv
......
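For reference, the `b_fold` context line above is the standard conv+BN folding identity: applying batch norm after a convolution is equivalent to a single convolution with `w_fold = gamma * istd * w` and `b_fold = beta + gamma * (b - mean) * istd`. A quick numpy check on a per-channel linear layer (a toy stand-in for the conv, not MegEngine code):

```python
import numpy as np

rng = np.random.default_rng(0)
cin, cout, n = 4, 3, 8
x = rng.normal(size=(n, cin))
w = rng.normal(size=(cout, cin))
b = rng.normal(size=cout)
gamma, beta = rng.normal(size=cout), rng.normal(size=cout)
mean, var = rng.normal(size=cout), rng.uniform(0.5, 2.0, size=cout)
istd = 1.0 / np.sqrt(var + 1e-5)

# Reference: batch norm applied on top of the layer's output.
y_ref = gamma * ((x @ w.T + b) - mean) * istd + beta

# Fold the BN statistics into the weight and bias (same algebra as b_fold).
w_fold = w * (gamma * istd)[:, None]
b_fold = beta + gamma * (b - mean) * istd
y_fold = x @ w_fold.T + b_fold

np.testing.assert_allclose(y_ref, y_fold, rtol=1e-10, atol=1e-12)
```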
@@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule):
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, x, w_qat)
+        else:
+            b_qat = self.bias
         return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
......
@@ -116,7 +116,7 @@ class TQT(_FakeQuantize):
     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter(0.0, dtype=np.float32)
+        self.scale = Parameter([0.0], dtype=np.float32)

     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
......
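The `TQT` hunk swaps a rank-0 parameter for a one-element rank-1 parameter, while the `HistogramObserver` hunks below drop the `[0]` indexing on `min_val`/`max_val`; together they make the scalar statistics consistent to read out. The shape difference, in plain numpy terms (an illustration, not MegEngine's `Parameter`):

```python
import numpy as np

scalar = np.array(0.0, dtype=np.float32)    # rank-0, like Parameter(0.0)
vector = np.array([0.0], dtype=np.float32)  # rank-1, like Parameter([0.0])

assert scalar.ndim == 0 and scalar.shape == ()
assert vector.ndim == 1 and vector.shape == (1,)

# Indexing with [0] works on the rank-1 value but raises on the rank-0 one.
assert vector[0] == 0.0
try:
    scalar[0]
except IndexError:
    pass  # "too many indices" for a 0-d array
```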
@@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver):
         By selecting new min/max, we filter out outliers in input distribution.
         """
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         np_histogram = self.histogram.numpy()
         assert len(np_histogram) == self.bins, "bins mistmatch"
         bin_width = (np_max_val - np_min_val) / self.bins
@@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver):
         # This allows us to have a common grid of resolution s, where we can align
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
         downsample_rate = int(
@@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver):
     def sideeffect_forward(self, x_orig):
         x = x_orig.numpy()
-        min_val = self.min_val.numpy()[0]
-        max_val = self.max_val.numpy()[0]
+        min_val = self.min_val.numpy()
+        max_val = self.max_val.numpy()
         histogram = self.histogram.numpy()
         new_min = x.min()
         new_max = x.max()
......
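All three `HistogramObserver` hunks make the same change: the min/max buffers are read with `.numpy()` instead of `.numpy()[0]`, presumably because they are now 0-d tensors. The downstream arithmetic (`bin_width`, `hist_bin_width`, the histogram update) is unaffected, since 0-d values broadcast like plain scalars in numpy:

```python
import numpy as np

bins = 4
# A 0-d array behaves like a scalar in arithmetic ...
np_min_val = np.array(-1.0, dtype=np.float32)
np_max_val = np.array(3.0, dtype=np.float32)
bin_width = (np_max_val - np_min_val) / bins
assert float(bin_width) == 1.0

# ... so observer-style histogram bookkeeping needs no other changes.
x = np.array([-0.5, 0.5, 1.5, 2.5], dtype=np.float32)
hist, edges = np.histogram(x, bins=bins, range=(float(np_min_val), float(np_max_val)))
assert hist.tolist() == [1, 1, 1, 1]
```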
@@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
         qmax = _metadata_dict["qint32"].qmax
         qmin = _metadata_dict["qint32"].qmin
         b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+        b_qat.q_dict.update(b_dict)
     return b_qat
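The single added line propagates the bias quantization metadata (`b_dict`) onto the fake-quantized tensor, so later passes can recover the parameters from `b_qat.q_dict`. For context, QAT frameworks conventionally fake-quantize the bias to a 32-bit integer grid whose scale is the product of the input and weight scales, which is what the `inp` and `w_qat` arguments to `fake_quant_bias` suggest. A hedged numpy sketch of that convention (names and qint32 bounds are assumptions):

```python
import numpy as np

def fake_quant_bias_sketch(bias, inp_scale, weight_scale):
    # Conventional QAT bias handling: quantize to int32 with
    # scale = inp_scale * weight_scale, then dequantize back to float.
    scale = inp_scale * weight_scale
    qmin, qmax = -(2 ** 31), 2 ** 31 - 1
    q = np.clip(np.round(bias / scale), qmin, qmax)
    b_qat = (q * scale).astype(np.float32)
    # Mirror of `b_qat.q_dict.update(b_dict)`: keep the metadata with
    # the value so downstream consumers can recover the quant params.
    q_dict = {"scale": scale, "dtype": "qint32"}
    return b_qat, q_dict

b, info = fake_quant_bias_sketch(np.array([0.1, -0.2], dtype=np.float32), 0.02, 0.01)
print(b, info["scale"])  # [ 0.1 -0.2] 0.0002
```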
@@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str):
 def _dump_graphviz(entries: List[ProfileEntry], path: str):
-    import graphviz
     import json
+
+    import graphviz

     graph = graphviz.Digraph()
     graph.graph_attr["ordering"] = "out"
     var_cache = {}
......
@@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply
 def elemwise(*args, mode):
-    from megengine.core.ops.builtin import Elemwise
     from megengine.core._imperative_rt.imperative import apply_op
+    from megengine.core.ops.builtin import Elemwise

     return apply_op(Elemwise(mode), args)
@@ -61,8 +61,8 @@ def test_tensor_on_device():
 def test_raw_tensor():
-    from megengine.core.tensor.raw_tensor import as_raw_tensor
     from megengine.core.ops.builtin import Elemwise
+    from megengine.core.tensor.raw_tensor import as_raw_tensor

     x = np.random.rand(10).astype("float32")
     xx = as_raw_tensor(x)
......
@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
 from megengine import module as Float
+from megengine import tensor
 from megengine.module import qat as QAT
-from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
+from megengine.quantization import min_max_fakequant_qconfig
+from megengine.quantization.quantize import (
+    _get_quantable_module_names,
+    disable_fake_quant,
+    quantize_qat,
+)


 def test_get_quantable_module_names():
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_disable_fake_quant():
+    class Net(Float.Module):
+        def __init__(self):
+            super().__init__()
+            self.quant = Float.QuantStub()
+            self.linear = Float.Linear(3, 3)
+            self.dequant = Float.DequantStub()
+            self.linear.bias.set_value(np.random.rand(3))
+
+        def forward(self, x):
+            x = self.quant(x)
+            x = self.linear(x)
+            x = self.dequant(x)
+            return x
+
+    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
+    net = Net()
+    y1 = net(x).numpy()
+    net = quantize_qat(net, min_max_fakequant_qconfig)
+    y2 = net(x).numpy()
+    disable_fake_quant(net)
+    y3 = net(x).numpy()
+    np.testing.assert_allclose(y1, y3)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y2, y3)
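The new test pins down the contract: `y1` (float net) must match `y3` (QAT net after `disable_fake_quant`), while `y2` (QAT net with fake quant on) must differ. For orientation, a sketch of what `disable_fake_quant` plausibly does; the traversal and attribute names are assumptions, not the real implementation in `megengine.quantization.quantize`:

```python
def disable_fake_quant_sketch(module):
    # Hypothetical traversal: visit every submodule and switch off any
    # attached fake-quant objects by clearing their `enabled` flag.
    for m in module.modules():  # assumed recursive-iteration API
        for attr in ("weight_fake_quant", "act_fake_quant"):
            fq = getattr(m, attr, None)
            if fq is not None:
                fq.enabled = False
```

With the `enabled` guards added in the conv/linear hunks above, clearing the flag now also turns off the bias fake-quant path, which is exactly what `assert_allclose(y1, y3)` verifies.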