Commit 0b4918b2 authored by: Megvii Engine Team

test(mge/quantization): classmethod `from_float_module` of qat module

GitOrigin-RevId: 95c3d45f83349825b7913556899002efdacdc971
Parent commit: bb369383
......@@ -8,7 +8,11 @@ import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_observer, propagate_qconfig
from megengine.quantization.quantize import (
disable_fake_quant,
disable_observer,
propagate_qconfig,
)
"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
......@@ -47,6 +51,12 @@ def init_qat_net(net):
def test_quant_stub():
normal_net = Float.QuantStub()
normal_net.eval()
qat_from_float = QAT.QuantStub.from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)
qat_net = QAT.QuantStub()
qat_net.eval()
disable_observer(qat_net)
......@@ -59,16 +69,25 @@ def test_quant_stub():
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
def test_dequant_stub():
normal_net = Float.DequantStub()
normal_net.eval()
qat_from_float = QAT.DequantStub.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)
qat_net = QAT.DequantStub()
qat_net.eval()
disable_observer(qat_net)
......@@ -83,17 +102,26 @@ def test_dequant_stub():
x = fake_quant(x, inp_scale)
x.q_dict["scale"] = inp_scale
normal_out = normal_net(x)
qat_out = qat_net(x)
q_out = q_net(quant(x, inp_scale)).numpy()
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = normal_net(x)
qat = qat_net(x)
q = q_net(quant(x, inp_scale)).numpy()
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
def test_elemwise(kind):
normal_net = Float.Elemwise(kind)
normal_net.eval()
qat_from_float = QAT.Elemwise.from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)
qat_net = QAT.Elemwise(kind)
qat_net.eval()
disable_observer(qat_net)
......@@ -117,16 +145,22 @@ def test_elemwise(kind):
x1_int8 = quant(x1, x1_scale)
x2_int8 = quant(x2, x2_scale)
# test correctness of `Float`, `QAT` and `Quantized`
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
normal_out = fake_quant(normal_net(x1, x2), act_scale)
qat_out = qat_net(x1, x2)
q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
normal = normal_net(x1, x2)
qat_without_fakequant = qat_from_float(x1, x2)
fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale)
qat = qat_net(x1, x2)
q = q_net(x1_int8, x2_int8).numpy() * act_scale
else:
normal_out = fake_quant(normal_net(x1), act_scale)
qat_out = qat_net(x1)
q_out = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x1)
qat_without_fakequant = qat_from_float(x1)
fake_quant_normal = fake_quant(normal_net(x1), act_scale)
qat = qat_net(x1)
q = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
def test_linear():
......@@ -153,20 +187,29 @@ def test_linear():
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
qat_from_float = QAT.Linear.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)
q_net = Q.Linear.from_qat_module(qat_net)
q_net.eval()
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
normal_net.eval()
qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
qat_net.eval()
disable_observer(qat_net)
......@@ -193,11 +236,19 @@ def test_conv(module):
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
qat_from_float = getattr(QAT, module).from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)
q_net = getattr(Q, module).from_qat_module(qat_net)
q_net.eval()
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register