From dc96f6aa49da23296ed38f350081e85e80051f5c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 30 Mar 2021 11:16:03 +0800
Subject: [PATCH] fix(mge/quantization): fix quantized concat forward problem

GitOrigin-RevId: dc21b340d1732fa5c1c186904d2a6c1b13e10121
---
 .../megengine/module/quantized/concat.py      |  2 +-
 .../python/megengine/quantization/__init__.py | 12 ++++-
 .../test/unit/quantization/test_module.py     | 46 ++++++++++++++++++-
 3 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/imperative/python/megengine/module/quantized/concat.py b/imperative/python/megengine/module/quantized/concat.py
index 11af85291..7fef963e8 100644
--- a/imperative/python/megengine/module/quantized/concat.py
+++ b/imperative/python/megengine/module/quantized/concat.py
@@ -23,7 +23,7 @@ class Concat(QuantizedModule):
         self.output_dtype = dtype
 
     def forward(self, inps: Iterable[Tensor], axis: int = 0):
-        new_inps = (x.astype(self.output_dtype) for x in inps)
+        new_inps = tuple(x.astype(self.output_dtype) for x in inps)
         return F.concat(new_inps, axis)
 
     @classmethod
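Why tuple() fixes the forward problem: a generator expression can be
consumed only once, so any callee that walks its inputs more than once
(for example, once to inspect dtypes and once to copy data) sees an
exhausted iterator on the second pass. Whatever the exact code path
inside F.concat, materializing the converted inputs into a tuple makes
new_inps safely re-iterable and gives it a len(). A minimal plain-Python
sketch of the failure mode, using a hypothetical validate_and_sum helper
rather than MegEngine's actual F.concat:

    def validate_and_sum(seq):
        # First pass: validation consumes a one-shot generator completely...
        if not all(isinstance(x, float) for x in seq):
            raise TypeError("expected floats")
        # ...so a second pass over a generator yields nothing.
        return sum(seq)

    gen = (float(x) for x in range(4))
    print(validate_and_sum(gen))                                # 0 (exhausted)
    print(validate_and_sum(tuple(float(x) for x in range(4))))  # 6.0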
diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py
index d33de8d7b..2d6bf959b 100644
--- a/imperative/python/megengine/quantization/__init__.py
+++ b/imperative/python/megengine/quantization/__init__.py
@@ -6,8 +6,16 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
-from .fake_quant import FakeQuantize
-from .observer import Observer
+from .fake_quant import TQT, FakeQuantize
+from .observer import (
+    ExponentialMovingAverageObserver,
+    HistogramObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
 from .qconfig import (
     QConfig,
     calibration_qconfig,
diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py
index cfdf03448..bbb95cff6 100644
--- a/imperative/python/test/unit/quantization/test_module.py
+++ b/imperative/python/test/unit/quantization/test_module.py
@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8"),
 )
 
-inp_scale = np.float32(np.random.rand() + 1)
+
+def gen_inp_scale():
+    return np.float32(np.random.rand() + 1)
+
 
 min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
 max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
@@ -116,6 +119,7 @@ def test_dequant_stub():
     q_net.eval()
 
     x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
     x = fake_quant_act(x, inp_scale)
     x.qparams.scale = inp_scale
 
@@ -192,6 +196,7 @@ def test_linear():
     init_qat_net(qat_net)
 
     x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
     x = fake_quant_act(x, inp_scale)
     x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
 
@@ -235,6 +240,7 @@ def test_conv(module):
     init_qat_net(qat_net)
 
     x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
     x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
 
@@ -269,3 +275,41 @@ def test_conv(module):
     np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
     np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
     np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)
+
+
+def test_concat():
+    normal_net = Float.Concat()
+    normal_net.eval()
+
+    qat_net = QAT.Concat()
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    inps = []
+    inps_int8 = []
+    for i in range(3):
+        inp_scale = gen_inp_scale()
+        inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
+        inps[i] = fake_quant_act(inps[i], inp_scale)
+        inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
+        inps_int8.append(quant(inps[i], inp_scale))
+
+    qat_from_float = QAT.Concat.from_float_module(normal_net)
+    qat_from_float.eval()
+    disable_fake_quant(qat_from_float)
+    disable_observer(qat_from_float)
+
+    q_net = Q.Concat.from_qat_module(qat_net)
+    q_net.eval()
+
+    normal = normal_net(inps)
+    qat_without_fakequant = qat_from_float(inps)
+    fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
+    qat = qat_net(inps)
+    q = q_net(inps_int8).numpy() * act_scale
+    np.testing.assert_allclose(qat_without_fakequant, normal)
+    np.testing.assert_allclose(qat, fake_quant_normal.numpy())
+    np.testing.assert_allclose(q, fake_quant_normal.numpy())
--
GitLab
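Two notes on the supporting hunks above. The test_module.py changes
replace the module-level inp_scale constant with a gen_inp_scale()
helper so that every call site draws a fresh random scale; test_concat
depends on this to give each of its three inputs a distinct scale. The
__init__.py change is a re-export widening, not a behavior change:
after this patch the concrete observers and TQT resolve from the
package root. A quick sketch of what becomes importable (names taken
verbatim from the diff):

    # Resolvable from the package root after this patch; previously only
    # Observer and FakeQuantize were re-exported by megengine.quantization.
    from megengine.quantization import (
        TQT,
        ExponentialMovingAverageObserver,
        HistogramObserver,
        MinMaxObserver,
        Observer,
        PassiveObserver,
        SyncExponentialMovingAverageObserver,
        SyncMinMaxObserver,
    )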