From 6c4841e807141ef4a1ce2aa7249d5ffe638bf34b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 23 Nov 2020 13:50:26 +0800 Subject: [PATCH] fix(mge/quantization): `disable_fake_quant` does not work correctly GitOrigin-RevId: 0c568d3335aa55cf0fc3fb569b9b124b581e5256 --- imperative/python/megengine/module/module.py | 2 +- .../python/megengine/module/qat/conv.py | 5 ++- .../python/megengine/module/qat/conv_bn.py | 5 ++- .../python/megengine/module/qat/linear.py | 5 ++- .../megengine/quantization/fake_quant.py | 2 +- .../python/megengine/quantization/observer.py | 12 +++--- .../python/megengine/quantization/utils.py | 1 + imperative/python/megengine/utils/profiler.py | 3 +- .../test/unit/core/test_imperative_rt.py | 4 +- .../python/test/unit/quantization/quantize.py | 38 ++++++++++++++++++- 10 files changed, 62 insertions(+), 15 deletions(-) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 41bb6720..d7ca2c91 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -57,7 +57,7 @@ def _is_module(obj): def _get_XNorm_typeclass(): from .batchnorm import _BatchNorm - from .normalization import GroupNorm, LayerNorm, InstanceNorm + from .normalization import GroupNorm, InstanceNorm, LayerNorm XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm) return XNorm_types diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index 315da839..a44db384 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule): def calc_conv_qat(self, inp): w_qat = self.apply_quant_weight(self.weight) - b_qat = fake_quant_bias(self.bias, inp, w_qat) + if self.weight_fake_quant and self.weight_fake_quant.enabled: + b_qat = fake_quant_bias(self.bias, inp, w_qat) + else: + b_qat = self.bias conv = self.calc_conv(inp, w_qat, b_qat) return conv diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index 89e270f1..3de3509b 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd w_qat = self.apply_quant_weight(w_fold) - b_qat = fake_quant_bias(b_fold, inp, w_qat) + if self.weight_fake_quant and self.weight_fake_quant.enabled: + b_qat = fake_quant_bias(b_fold, inp, w_qat) + else: + b_qat = b_fold conv = self.conv.calc_conv(inp, w_qat, b_qat) if not (self.training and approx): return conv diff --git a/imperative/python/megengine/module/qat/linear.py b/imperative/python/megengine/module/qat/linear.py index 6c57beca..d0b454c0 100644 --- a/imperative/python/megengine/module/qat/linear.py +++ b/imperative/python/megengine/module/qat/linear.py @@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule): def forward(self, x): w_qat = self.apply_quant_weight(self.weight) - b_qat = fake_quant_bias(self.bias, x, w_qat) + if self.weight_fake_quant and self.weight_fake_quant.enabled: + b_qat = fake_quant_bias(self.bias, x, w_qat) + else: + b_qat = self.bias return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat)) @classmethod diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index 774a7cae..2846131f 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -116,7 +116,7 @@ class TQT(_FakeQuantize): def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): super().__init__(dtype, narrow_range, enable) - self.scale = Parameter(0.0, dtype=np.float32) + self.scale = Parameter([0.0], dtype=np.float32) def fake_quant_forward(self, inp, q_dict=None): # when enable, TQT will do fakequant forward, finetune the scale diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 7d3f452e..e3528e15 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver): By selecting new min/max, we filter out outliers in input distribution. """ - np_min_val = self.min_val.numpy()[0] - np_max_val = self.max_val.numpy()[0] + np_min_val = self.min_val.numpy() + np_max_val = self.max_val.numpy() np_histogram = self.histogram.numpy() assert len(np_histogram) == self.bins, "bins mistmatch" bin_width = (np_max_val - np_min_val) / self.bins @@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver): # This allows us to have a common grid of resolution s, where we can align # the input histogram # start_idx maps min_val to the histogram bin index. - np_min_val = self.min_val.numpy()[0] - np_max_val = self.max_val.numpy()[0] + np_min_val = self.min_val.numpy() + np_max_val = self.max_val.numpy() hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate) downsample_rate = int( @@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver): def sideeffect_forward(self, x_orig): x = x_orig.numpy() - min_val = self.min_val.numpy()[0] - max_val = self.max_val.numpy()[0] + min_val = self.min_val.numpy() + max_val = self.max_val.numpy() histogram = self.histogram.numpy() new_min = x.min() new_max = x.max() diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 284ce8c5..970c6bee 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: qmax = _metadata_dict["qint32"].qmax qmin = _metadata_dict["qint32"].qmin b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict) + b_qat.q_dict.update(b_dict) return b_qat diff --git a/imperative/python/megengine/utils/profiler.py b/imperative/python/megengine/utils/profiler.py index b437c03d..0fe88de9 100644 --- a/imperative/python/megengine/utils/profiler.py +++ b/imperative/python/megengine/utils/profiler.py @@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str): def _dump_graphviz(entries: List[ProfileEntry], path: str): - import graphviz import json + import graphviz + graph = graphviz.Digraph() graph.graph_attr["ordering"] = "out" var_cache = {} diff --git a/imperative/python/test/unit/core/test_imperative_rt.py b/imperative/python/test/unit/core/test_imperative_rt.py index bc622faf..9b8764aa 100644 --- a/imperative/python/test/unit/core/test_imperative_rt.py +++ b/imperative/python/test/unit/core/test_imperative_rt.py @@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply def elemwise(*args, mode): - from megengine.core.ops.builtin import Elemwise from megengine.core._imperative_rt.imperative import apply_op + from megengine.core.ops.builtin import Elemwise return apply_op(Elemwise(mode), args) @@ -61,8 +61,8 @@ def test_tensor_on_device(): def test_raw_tensor(): - from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.core.ops.builtin import Elemwise + from megengine.core.tensor.raw_tensor import as_raw_tensor x = np.random.rand(10).astype("float32") xx = as_raw_tensor(x) diff --git a/imperative/python/test/unit/quantization/quantize.py b/imperative/python/test/unit/quantization/quantize.py index 236ef9e1..e912e0c0 100644 --- a/imperative/python/test/unit/quantization/quantize.py +++ b/imperative/python/test/unit/quantization/quantize.py @@ -5,9 +5,18 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np +import pytest + from megengine import module as Float +from megengine import tensor from megengine.module import qat as QAT -from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat +from megengine.quantization import min_max_fakequant_qconfig +from megengine.quantization.quantize import ( + _get_quantable_module_names, + disable_fake_quant, + quantize_qat, +) def test_get_quantable_module_names(): @@ -78,3 +87,30 @@ def test_convert_with_custom_mapping(): net = Net() qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) assert isinstance(qat_net.example, QATExample) + + +def test_disable_fake_quant(): + class Net(Float.Module): + def __init__(self): + super().__init__() + self.quant = Float.QuantStub() + self.linear = Float.Linear(3, 3) + self.dequant = Float.DequantStub() + self.linear.bias.set_value(np.random.rand(3)) + + def forward(self, x): + x = self.quant(x) + x = self.linear(x) + x = self.dequant(x) + return x + + x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) + net = Net() + y1 = net(x).numpy() + net = quantize_qat(net, min_max_fakequant_qconfig) + y2 = net(x).numpy() + disable_fake_quant(net) + y3 = net(x).numpy() + np.testing.assert_allclose(y1, y3) + with pytest.raises(AssertionError): + np.testing.assert_allclose(y2, y3) -- GitLab