diff --git a/imperative/python/megengine/module/quantized/concat.py b/imperative/python/megengine/module/quantized/concat.py index 11af852913f7430b27688663eb65808f39159a54..7fef963e827d65340cea443b702526fd2b716b63 100644 --- a/imperative/python/megengine/module/quantized/concat.py +++ b/imperative/python/megengine/module/quantized/concat.py @@ -23,7 +23,7 @@ class Concat(QuantizedModule): self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): - new_inps = (x.astype(self.output_dtype) for x in inps) + new_inps = tuple(x.astype(self.output_dtype) for x in inps) return F.concat(new_inps, axis) @classmethod diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py index d33de8d7b328debc989d1693b3766c404baf5a03..2d6bf959b3b460152db31a5755728e6fe5439eb9 100644 --- a/imperative/python/megengine/quantization/__init__.py +++ b/imperative/python/megengine/quantization/__init__.py @@ -6,8 +6,16 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-from .fake_quant import FakeQuantize -from .observer import Observer +from .fake_quant import TQT, FakeQuantize +from .observer import ( + ExponentialMovingAverageObserver, + HistogramObserver, + MinMaxObserver, + Observer, + PassiveObserver, + SyncExponentialMovingAverageObserver, + SyncMinMaxObserver, +) from .qconfig import ( QConfig, calibration_qconfig, diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py index cfdf03448a6777f4ba627abbccda487694118d4d..bbb95cff6eb92a69e4eb21ef290865c86bea3ab1 100644 --- a/imperative/python/test/unit/quantization/test_module.py +++ b/imperative/python/test/unit/quantization/test_module.py @@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig( act_fake_quant=partial(FakeQuantize, dtype="qint8"), ) -inp_scale = np.float32(np.random.rand() + 1) + +def gen_inp_scale(): + return np.float32(np.random.rand() + 1) + min_val = np.random.randint(-127, 0, size=(2,)).astype("float32") max_val = np.random.randint(1, 127, size=(2,)).astype("float32") @@ -116,6 +119,7 @@ def test_dequant_stub(): q_net.eval() x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) + inp_scale = gen_inp_scale() x = fake_quant_act(x, inp_scale) x.qparams.scale = inp_scale @@ -192,6 +196,7 @@ def test_linear(): init_qat_net(qat_net) x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) + inp_scale = gen_inp_scale() x = fake_quant_act(x, inp_scale) x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) @@ -235,6 +240,7 @@ def test_conv(module): init_qat_net(qat_net) x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) + inp_scale = gen_inp_scale() x = fake_quant_act(x, inp_scale) x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) @@ -269,3 +275,41 @@ def test_conv(module): np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5) np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale) 
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale) + + +def test_concat(): + normal_net = Float.Concat() + normal_net.eval() + + qat_net = QAT.Concat() + qat_net.eval() + disable_observer(qat_net) + + propagate_qconfig(qat_net, min_max_fakequant_qconfig) + init_qat_net(qat_net) + + inps = [] + inps_int8 = [] + for i in range(3): + inp_scale = gen_inp_scale() + inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))) + inps[i] = fake_quant_act(inps[i], inp_scale) + inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) + inps_int8.append(quant(inps[i], inp_scale)) + + qat_from_float = QAT.Concat.from_float_module(normal_net) + qat_from_float.eval() + disable_fake_quant(qat_from_float) + disable_observer(qat_from_float) + + q_net = Q.Concat.from_qat_module(qat_net) + q_net.eval() + + normal = normal_net(inps) + qat_without_fakequant = qat_from_float(inps) + fake_quant_normal = fake_quant_act(normal_net(inps), act_scale) + qat = qat_net(inps) + q = q_net(inps_int8).numpy() * act_scale + np.testing.assert_allclose(qat_without_fakequant, normal) + np.testing.assert_allclose(qat, fake_quant_normal.numpy()) + np.testing.assert_allclose(q, fake_quant_normal.numpy())