From 94796060b67d49d48417d8b6c8c3d01c245a0e59 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 15 Dec 2020 14:41:05 +0800
Subject: [PATCH] test(mge/quantization): qat module, quantized op and module

GitOrigin-RevId: 17b28060ccf320972a22663eee15578b4b9798ee
---
 .../megengine/module/quantized/linear.py      |   7 +-
 .../python/megengine/quantization/__init__.py |   1 +
 .../python/megengine/quantization/qconfig.py  |   2 +
 .../python/megengine/quantization/quantize.py |  18 +-
 .../test/unit/quantization/test_module.py     | 203 ++++++++++++++++++
 .../test/unit/quantization/test_observer.py   |   4 +-
 .../python/test/unit/quantization/test_op.py  | 161 ++++++++++++++
 .../test/unit/quantization/test_quantize.py   |  14 +-
 8 files changed, 389 insertions(+), 21 deletions(-)
 create mode 100644 imperative/python/test/unit/quantization/test_module.py
 create mode 100644 imperative/python/test/unit/quantization/test_op.py

diff --git a/imperative/python/megengine/module/quantized/linear.py b/imperative/python/megengine/module/quantized/linear.py
index 7f8ac43dc..1a817c5cc 100644
--- a/imperative/python/megengine/module/quantized/linear.py
+++ b/imperative/python/megengine/module/quantized/linear.py
@@ -29,11 +29,14 @@ class Linear(QuantizedModule):
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_dtype = dtype.qint32(inp_scale * w_scale)
-        return F.nn.linear(
+        ret = F.nn.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
-        ).astype(self.output_dtype)
+        )
+        ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        return ret
+
 
     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):
diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py
index 427365e5a..452450781 100644
--- a/imperative/python/megengine/quantization/__init__.py
+++ b/imperative/python/megengine/quantization/__init__.py
@@ -12,6 +12,7 @@ from .observer import HistogramObserver, Observer
 from .qconfig import (
     QConfig,
     calibration_qconfig,
+    easyquant_qconfig,
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
diff --git a/imperative/python/megengine/quantization/qconfig.py b/imperative/python/megengine/quantization/qconfig.py
index 5213a622f..5eef6c32d 100644
--- a/imperative/python/megengine/quantization/qconfig.py
+++ b/imperative/python/megengine/quantization/qconfig.py
@@ -138,3 +138,5 @@ passive_qconfig = QConfig(
     weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )
+
+easyquant_qconfig = passive_qconfig
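The new easyquant_qconfig is a plain alias of passive_qconfig: EasyQuant does not observe statistics itself, it only tunes scales that already exist. A minimal usage sketch, assuming a hypothetical float model class Net and an arbitrary input shape (quantize_qat and apply_easy_quant are the entry points this patch exercises):

    import numpy as np
    import megengine as mge
    from megengine.quantization import easyquant_qconfig
    from megengine.quantization.quantize import apply_easy_quant, quantize_qat

    net = Net()  # hypothetical float model
    qat_net = quantize_qat(net, qconfig=easyquant_qconfig)
    # passive config: scales are assumed to be loaded or calibrated already
    data = mge.tensor(np.random.normal(size=(8, 3, 32, 32)).astype("float32"))
    apply_easy_quant(qat_net, data, start=0.8, stop=1.2, num=40)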
diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py
index d6216d9b7..09b6466dd 100644
--- a/imperative/python/megengine/quantization/quantize.py
+++ b/imperative/python/megengine/quantization/quantize.py
@@ -223,11 +223,11 @@ def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
 
         mod._forward_hooks.clear()
 
-        fp32_in = [_[:batch_size] for _ in inputs]
-        int8_in = [_[batch_size:] for _ in inputs]
+        normal_in = [_[:batch_size] for _ in inputs]
+        fakequant_in = [_[batch_size:] for _ in inputs]
 
         disable_fake_quant(mod)
-        fp32_out = mod(*fp32_in)
+        normal_out = mod(*normal_in)
         enable_fake_quant(mod)
 
         ob = getattr(mod, where)
@@ -239,19 +239,15 @@
         best_scale = 0
         for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
             ob.scale = scale
-            int8_out = mod(*int8_in)
-            dis = get_cosine(fp32_out, int8_out)
+            fakequant_out = mod(*fakequant_in)
+            dis = get_cosine(normal_out, fakequant_out)
             if dis > distance:
                 distance = dis
                 best_scale = scale
         ob.scale = best_scale
 
-        if where == "act_observer":
-            int8_out = mod(*int8_in)
-            return concat([fp32_out, int8_out])
-        else:
-            int8_out = outputs[batch_size:]
-            return concat([fp32_out, int8_out])
+        fakequant_out = outputs[batch_size:]
+        return concat([normal_out, fakequant_out])
 
     data = concat([data, data])
 
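The renaming above (fp32/int8 to normal/fakequant) reflects what apply_easy_quant actually compares: a normal forward pass against a fake-quantized one, keeping the scale with the highest cosine similarity on a dense grid. A standalone numpy sketch of that search, with illustrative names that are not part of the MegEngine API:

    import numpy as np

    def get_cosine(a, b):
        # cosine similarity over flattened outputs
        a, b = a.ravel(), b.ravel()
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)

    def search_scale(normal_out, run_with_scale, orig_scale, start=0.8, stop=1.2, num=40):
        # evaluate num candidate scales around orig_scale and keep the best one
        best_scale, best_dis = orig_scale, -1.0
        for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
            dis = get_cosine(normal_out, run_with_scale(scale))
            if dis > best_dis:
                best_dis, best_scale = dis, scale
        return best_scale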
diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py
new file mode 100644
index 000000000..0961bae77
--- /dev/null
+++ b/imperative/python/test/unit/quantization/test_module.py
@@ -0,0 +1,203 @@
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as Float
+import megengine.module.qat as QAT
+import megengine.module.quantized as Q
+from megengine.core.tensor import dtype
+from megengine.quantization import min_max_fakequant_qconfig
+from megengine.quantization.quantize import disable_observer, propagate_qconfig
+
+"""
+Calculate testing scales based on ``min_max_fakequant_qconfig``
+"""
+
+inp_scale = np.float32(np.random.rand() + 1)
+
+min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
+max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
+weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
+act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
+
+
+def quant(x, scale):
+    inp_dtype = dtype.qint8(scale)
+    return x.astype(inp_dtype)
+
+
+def fake_quant(x, scale):
+    x = x / scale
+    x = F.round(x)
+    x = F.clip(x, -128, 127)
+    x = x * scale
+    return x
+
+
+def init_qat_net(net):
+    if net.with_weight:
+        net.weight_observer.min_val.set_value(min_val[0])
+        net.weight_observer.max_val.set_value(max_val[0])
+    if net.with_act:
+        net.act_observer.min_val.set_value(min_val[1])
+        net.act_observer.max_val.set_value(max_val[1])
+
+
+def test_quant_stub():
+    normal_net = Float.QuantStub()
+    normal_net.eval()
+    qat_net = QAT.QuantStub()
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    q_net = Q.QuantStub.from_qat_module(qat_net)
+    q_net.eval()
+
+    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+
+    normal_out = fake_quant(normal_net(x), act_scale)
+    qat_out = qat_net(x)
+    q_out = q_net(x).numpy() * act_scale
+    np.testing.assert_allclose(qat_out, normal_out)
+    np.testing.assert_allclose(q_out, normal_out.numpy())
+
+
+def test_dequant_stub():
+    normal_net = Float.DequantStub()
+    normal_net.eval()
+    qat_net = QAT.DequantStub()
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    q_net = Q.DequantStub.from_qat_module(qat_net)
+    q_net.eval()
+
+    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x = fake_quant(x, inp_scale)
+    x.q_dict["scale"] = inp_scale
+
+    normal_out = normal_net(x)
+    qat_out = qat_net(x)
+    q_out = q_net(quant(x, inp_scale)).numpy()
+    np.testing.assert_allclose(qat_out, normal_out)
+    np.testing.assert_allclose(q_out, normal_out.numpy())
+
+
+@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
+def test_elemwise(kind):
+    normal_net = Float.Elemwise(kind)
+    normal_net.eval()
+    qat_net = QAT.Elemwise(kind)
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    q_net = Q.Elemwise.from_qat_module(qat_net)
+    q_net.eval()
+
+    x1_scale = np.float32(np.random.rand() + 1)
+    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x1 = fake_quant(x1, x1_scale)
+    x1.q_dict["scale"] = x1_scale
+
+    x2_scale = np.float32(np.random.rand() + 1)
+    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x2 = fake_quant(x2, x2_scale)
+    x2.q_dict["scale"] = x2_scale
+
+    x1_int8 = quant(x1, x1_scale)
+    x2_int8 = quant(x2, x2_scale)
+
+    if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
+        normal_out = fake_quant(normal_net(x1, x2), act_scale)
+        qat_out = qat_net(x1, x2)
+        q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
+    else:
+        normal_out = fake_quant(normal_net(x1), act_scale)
+        qat_out = qat_net(x1)
+        q_out = q_net(x1_int8).numpy() * act_scale
+    np.testing.assert_allclose(qat_out, normal_out)
+    np.testing.assert_allclose(q_out, normal_out.numpy())
+
+
+def test_linear():
+    normal_net = Float.Linear(3, 3, bias=True)
+    normal_net.eval()
+
+    qat_net = QAT.Linear(3, 3, bias=True)
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x = fake_quant(x, inp_scale)
+    x.q_dict["scale"] = inp_scale
+
+    x_int8 = quant(x, inp_scale)
+
+    weight = np.random.normal(size=(3, 3)).astype("float32")
+    bias = np.random.normal(size=(3,)).astype("float32")
+    normal_net.weight.set_value(fake_quant(weight, weight_scale))
+    normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
+    qat_net.weight.set_value(weight)
+    qat_net.bias.set_value(bias)
+
+    q_net = Q.Linear.from_qat_module(qat_net)
+    q_net.eval()
+
+    normal_out = fake_quant(normal_net(x), act_scale)
+    qat_out = qat_net(x)
+    q_out = q_net(x_int8).numpy() * act_scale
+    np.testing.assert_allclose(qat_out, normal_out)
+    np.testing.assert_allclose(q_out, normal_out.numpy())
+
+
+@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
+def test_conv(module):
+    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
+    normal_net.eval()
+    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
+    x = fake_quant(x, inp_scale)
+    x.q_dict["scale"] = inp_scale
+
+    x_int8 = quant(x, inp_scale)
+
+    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
+    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
+    if module in ("ConvBn2d", "ConvBnRelu2d"):
+        normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
+        normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
+        qat_net.conv.weight.set_value(weight)
+        qat_net.conv.bias.set_value(bias)
+    else:
+        normal_net.weight.set_value(fake_quant(weight, weight_scale))
+        normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
+        qat_net.weight.set_value(weight)
+        qat_net.bias.set_value(bias)
+
+    q_net = getattr(Q, module).from_qat_module(qat_net)
+    q_net.eval()
+
+    normal_out = fake_quant(normal_net(x), act_scale)
+    qat_out = qat_net(x)
+    q_out = q_net(x_int8).numpy() * act_scale
+    np.testing.assert_allclose(qat_out, normal_out)
+    np.testing.assert_allclose(q_out, normal_out.numpy())
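The scale formulas at the top of test_module.py follow the symmetric qint8 convention: weights use the narrow range [-127, 127] (254 steps, hence / 254 * 2) while activations use the full range [-128, 127] (255 steps, hence / 255 * 2), and fake_quant is the quantize-dequantize round trip the quantized outputs are checked against. A small self-contained check of that arithmetic, with arbitrary values:

    import numpy as np

    min_val, max_val = -100.0, 120.0
    weight_scale = max(-min_val, max_val) / 254 * 2  # narrow range: q in [-127, 127]
    act_scale = max(-min_val, max_val) / 255 * 2     # full range: q in [-128, 127]

    def fake_quant(x, scale, qmin=-128, qmax=127):
        # quantize-dequantize round trip
        return np.clip(np.round(x / scale), qmin, qmax) * scale

    x = 3.1415
    # for in-range values the round-trip error is at most half a step
    assert abs(fake_quant(x, act_scale) - x) <= act_scale / 2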
diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py
index f27a8514d..aa9622640 100644
--- a/imperative/python/test/unit/quantization/test_observer.py
+++ b/imperative/python/test/unit/quantization/test_observer.py
@@ -103,7 +103,7 @@ def test_sync_exponential_moving_average_observer():
         y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
         m(y1)
         m(y2)
-        np.testing.assert_allclose(m.min_val.numpy(), expected_min)
-        np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+        np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-6)
+        np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-6)
 
     worker()
diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
new file mode 100644
index 000000000..20e5b2fc6
--- /dev/null
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -0,0 +1,161 @@
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine.core.tensor import dtype
+from megengine.distributed.helper import get_device_count_by_fork
+from megengine.functional.elemwise import _elemwise_multi_type, _elwise
+
+
+def quant(x, scale):
+    x_dtype = dtype.qint8(scale)
+    return x.astype(x_dtype)
+
+
+def fake_quant(x, scale):
+    x = x / scale
+    x = F.round(x)
+    x = F.clip(x, -128, 127)
+    x = x * scale
+    return x
+
+
+@pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"])
+def test_elemwise(kind):
+    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x1_scale = np.float32(np.random.rand() + 1)
+    x1 = fake_quant(x1, x1_scale)
+    x1.q_dict["scale"] = x1_scale
+    x1_int8 = quant(x1, x1_scale)
+
+    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    x2_scale = np.float32(np.random.rand() + 1)
+    x2 = fake_quant(x2, x2_scale)
+    x2.q_dict["scale"] = x2_scale
+    x2_int8 = quant(x2, x2_scale)
+
+    output_scale = np.float32(np.random.rand() + 1)
+    output_dtype = dtype.qint8(output_scale)
+
+    quantized_kind = "Q" + kind
+    if kind in ("ABS", "SIN"):
+        desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
+        actual_out = (
+            _elemwise_multi_type(
+                x1_int8, mode=quantized_kind, dtype=output_dtype
+            ).numpy()
+            * output_scale
+        )
+    else:
+        desired_out = fake_quant(_elwise(x1, x2, mode=kind), output_scale)
+        actual_out = (
+            _elemwise_multi_type(
+                x1_int8, x2_int8, mode=quantized_kind, dtype=output_dtype
+            ).numpy()
+            * output_scale
+        )
+    np.testing.assert_allclose(actual_out, desired_out.numpy())
+
+
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
+)
+def test_conv_bias():
+    inp_scale = np.float32(np.random.rand() + 1)
+    w_scale = np.float32(np.random.rand() + 1)
+    outp_scale = np.float32(np.random.rand() + 1)
+    inp_dtype = dtype.qint8(inp_scale)
+    w_dtype = dtype.qint8(w_scale)
+    b_dtype = dtype.qint32(inp_scale * w_scale)
+    out_dtype = dtype.qint8(outp_scale)
+
+    def run(
+        N,
+        IC,
+        OC,
+        IH,
+        IW,
+        KH,
+        KW,
+        PH,
+        PW,
+        SH,
+        SW,
+        has_bias=True,
+        nonlinear_mode="IDENTITY",
+    ):
+        inp_v = np.random.normal(size=(N, IC, IH, IW))
+        w_v = np.random.normal(size=(OC, IC, KH, KW))
+        b_v = np.random.normal(size=(1, OC, 1, 1))
+        inp_scale = dtype.get_scale(inp_dtype)
+        w_scale = dtype.get_scale(w_dtype)
+        b_scale = dtype.get_scale(b_dtype)
+
+        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
+        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
+        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
+
+        inp_int8 = mge.tensor(inpv, dtype=inp_dtype)
+        w_int8 = mge.Parameter(wv, dtype=w_dtype)
+        b_int32 = mge.Parameter(bv, dtype=b_dtype)
+
+        inp_fp32 = inp_int8.astype("float32")
+        w_fp32 = w_int8.astype("float32")
+        b_fp32 = b_int32.astype("float32")
+
+        def convert_to_nchw4(var):
+            var = F.reshape(
+                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
+            )
+            var = F.transpose(var, (0, 1, 3, 4, 2))
+            return var
+
+        def run_conv2d(inp, w, b):
+            O = F.conv2d(
+                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
+            )
+            if nonlinear_mode == "RELU":
+                return F.relu(O)
+            else:
+                return O
+
+        def run_conv_bias(inp, w, b, format="NCHW"):
+            b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
+            if format == "NCHW4":
+                inp = convert_to_nchw4(inp)
+                w = convert_to_nchw4(w)
+                b = convert_to_nchw4(b)
+            return F.quantized.conv_bias_activation(
+                inp,
+                w,
+                b,
+                stride=(SH, SW),
+                padding=(PH, PW),
+                dtype=out_dtype,
+                nonlinear_mode=nonlinear_mode,
+            )
+
+        format = "NCHW4" if mge.is_cuda_available() else "NCHW"
+
+        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
+        expected = expected.astype(out_dtype).astype("float32")
+        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
+            "float32"
+        )
+        if format == "NCHW4":
+            result = F.transpose(result, (0, 1, 4, 2, 3))
+        expected = F.flatten(expected)
+        result = F.flatten(result)
+        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
+
+    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
+    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
+
+    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
+    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
+
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
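The qint32 bias dtype in test_conv_bias (and in the linear.py change above) encodes how int8 convolution accumulates: inputs at scale s_x times weights at scale s_w produce int32 accumulators at scale s_x * s_w, so the bias has to be quantized to exactly that scale before it is added. A toy check of this bookkeeping, with arbitrary numbers:

    import numpy as np

    s_x, s_w = 0.05, 0.02
    bias_fp32 = 1.23
    bias_q = int(np.round(bias_fp32 / (s_x * s_w)))  # qint32 value at scale s_x * s_w
    # dequantizing recovers the float bias up to one accumulator step
    assert abs(bias_q * (s_x * s_w) - bias_fp32) <= s_x * s_w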
diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py
index 20562d8be..0ce72422c 100644
--- a/imperative/python/test/unit/quantization/test_quantize.py
+++ b/imperative/python/test/unit/quantization/test_quantize.py
@@ -88,12 +88,14 @@ def test_propagate_qconfig():
 def init_qat_net():
     net = QATNet()
     propagate_qconfig(net, min_max_fakequant_qconfig)
-    min_val = np.random.randint(-127, 0, size=(2,))
-    max_val = np.random.randint(1, 127, size=(2,))
-    net.linear.weight_observer.min_val.set_value(min_val[0])
-    net.linear.weight_observer.max_val.set_value(max_val[0])
-    net.linear.act_observer.min_val.set_value(min_val[1])
-    net.linear.act_observer.max_val.set_value(max_val[1])
+    min_val = np.random.randint(-127, 0, size=(3,))
+    max_val = np.random.randint(1, 127, size=(3,))
+    net.quant.act_observer.min_val.set_value(min_val[0])
+    net.quant.act_observer.max_val.set_value(max_val[0])
+    net.linear.weight_observer.min_val.set_value(min_val[1])
+    net.linear.weight_observer.max_val.set_value(max_val[1])
+    net.linear.act_observer.min_val.set_value(min_val[2])
+    net.linear.act_observer.max_val.set_value(max_val[2])
     return net
 
 
-- 
GitLab
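Assuming a source checkout where these paths exist, the new test files can be run directly from the repository root:

    pytest imperative/python/test/unit/quantization/test_module.py \
           imperative/python/test/unit/quantization/test_op.py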