提交 94796060 编写于 作者: M Megvii Engine Team

test(mge/quantization): qat module, quantized op and module

GitOrigin-RevId: 17b28060ccf320972a22663eee15578b4b9798ee
上级 ab9f44f1
......@@ -29,11 +29,14 @@ class Linear(QuantizedModule):
inp_scale = dtype.get_scale(inp.dtype)
w_scale = dtype.get_scale(self.weight.dtype)
bias_dtype = dtype.qint32(inp_scale * w_scale)
return F.nn.linear(
ret = F.nn.linear(
inp,
self.weight,
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)
)
ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
return ret
@classmethod
def from_qat_module(cls, qat_module: QAT.Linear):
......
......@@ -12,6 +12,7 @@ from .observer import HistogramObserver, Observer
from .qconfig import (
QConfig,
calibration_qconfig,
easyquant_qconfig,
ema_fakequant_qconfig,
ema_lowbit_fakequant_qconfig,
min_max_fakequant_qconfig,
......
......@@ -138,3 +138,5 @@ passive_qconfig = QConfig(
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
easyquant_qconfig = passive_qconfig
......@@ -223,11 +223,11 @@ def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
mod._forward_hooks.clear()
fp32_in = [_[:batch_size] for _ in inputs]
int8_in = [_[batch_size:] for _ in inputs]
normal_in = [_[:batch_size] for _ in inputs]
fakequant_in = [_[batch_size:] for _ in inputs]
disable_fake_quant(mod)
fp32_out = mod(*fp32_in)
normal_out = mod(*normal_in)
enable_fake_quant(mod)
ob = getattr(mod, where)
......@@ -239,19 +239,15 @@ def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
best_scale = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale
int8_out = mod(*int8_in)
dis = get_cosine(fp32_out, int8_out)
fakequant_out = mod(*fakequant_in)
dis = get_cosine(normal_out, fakequant_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if where == "act_observer":
int8_out = mod(*int8_in)
return concat([fp32_out, int8_out])
else:
int8_out = outputs[batch_size:]
return concat([fp32_out, int8_out])
fakequant_out = outputs[batch_size:]
return concat([normal_out, fakequant_out])
data = concat([data, data])
......
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
import megengine.module as Float
import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_observer, propagate_qconfig
"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
"""
inp_scale = np.float32(np.random.rand() + 1)
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)
def quant(x, scale):
inp_dtype = dtype.qint8(scale)
return x.astype(inp_dtype)
def fake_quant(x, scale):
x = x / scale
x = F.round(x)
x = F.clip(x, -128, 127)
x = x * scale
return x
def init_qat_net(net):
if net.with_weight:
net.weight_observer.min_val.set_value(min_val[0])
net.weight_observer.max_val.set_value(max_val[0])
if net.with_act:
net.act_observer.min_val.set_value(min_val[1])
net.act_observer.max_val.set_value(max_val[1])
def test_quant_stub():
normal_net = Float.QuantStub()
normal_net.eval()
qat_net = QAT.QuantStub()
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
q_net = Q.QuantStub.from_qat_module(qat_net)
q_net.eval()
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
def test_dequant_stub():
normal_net = Float.DequantStub()
normal_net.eval()
qat_net = QAT.DequantStub()
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
q_net = Q.DequantStub.from_qat_module(qat_net)
q_net.eval()
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x.q_dict["scale"] = inp_scale
normal_out = normal_net(x)
qat_out = qat_net(x)
q_out = q_net(quant(x, inp_scale)).numpy()
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
def test_elemwise(kind):
normal_net = Float.Elemwise(kind)
normal_net.eval()
qat_net = QAT.Elemwise(kind)
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
q_net = Q.Elemwise.from_qat_module(qat_net)
q_net.eval()
x1_scale = np.float32(np.random.rand() + 1)
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1 = fake_quant(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x2_scale = np.float32(np.random.rand() + 1)
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x2 = fake_quant(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x1_int8 = quant(x1, x1_scale)
x2_int8 = quant(x2, x2_scale)
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
normal_out = fake_quant(normal_net(x1, x2), act_scale)
qat_out = qat_net(x1, x2)
q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
else:
normal_out = fake_quant(normal_net(x1), act_scale)
qat_out = qat_net(x1)
q_out = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
def test_linear():
normal_net = Float.Linear(3, 3, bias=True)
normal_net.eval()
qat_net = QAT.Linear(3, 3, bias=True)
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x.q_dict["scale"] = inp_scale
x_int8 = quant(x, inp_scale)
weight = np.random.normal(size=(3, 3)).astype("float32")
bias = np.random.normal(size=(3,)).astype("float32")
normal_net.weight.set_value(fake_quant(weight, weight_scale))
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
q_net = Q.Linear.from_qat_module(qat_net)
q_net.eval()
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
normal_net.eval()
qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x.q_dict["scale"] = inp_scale
x_int8 = quant(x, inp_scale)
weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
if module in ("ConvBn2d", "ConvBnRelu2d"):
normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.conv.weight.set_value(weight)
qat_net.conv.bias.set_value(bias)
else:
normal_net.weight.set_value(fake_quant(weight, weight_scale))
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
q_net = getattr(Q, module).from_qat_module(qat_net)
q_net.eval()
normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
......@@ -103,7 +103,7 @@ def test_sync_exponential_moving_average_observer():
y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
m(y1)
m(y2)
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-6)
np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-6)
worker()
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
def quant(x, scale):
x_dtype = dtype.qint8(scale)
return x.astype(x_dtype)
def fake_quant(x, scale):
x = x / scale
x = F.round(x)
x = F.clip(x, -128, 127)
x = x * scale
return x
@pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"])
def test_elemwise(kind):
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1_scale = np.float32(np.random.rand() + 1)
x1 = fake_quant(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x1_int8 = quant(x1, x1_scale)
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x2_scale = np.float32(np.random.rand() + 1)
x2 = fake_quant(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x2_int8 = quant(x2, x2_scale)
output_scale = np.float32(np.random.rand() + 1)
output_dtype = dtype.qint8(output_scale)
quantized_kind = "Q" + kind
if kind in ("ABS", "SIN"):
desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
actual_out = (
_elemwise_multi_type(
x1_int8, mode=quantized_kind, dtype=output_dtype
).numpy()
* output_scale
)
else:
desired_out = fake_quant(_elwise(x1, x2, mode=kind), output_scale)
actual_out = (
_elemwise_multi_type(
x1_int8, x2_int8, mode=quantized_kind, dtype=output_dtype
).numpy()
* output_scale
)
np.testing.assert_allclose(actual_out, desired_out.numpy())
@pytest.mark.skipif(
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
inp_scale = np.float32(np.random.rand() + 1)
w_scale = np.float32(np.random.rand() + 1)
outp_scale = np.float32(np.random.rand() + 1)
inp_dtype = dtype.qint8(inp_scale)
w_dtype = dtype.qint8(w_scale)
b_dtype = dtype.qint32(inp_scale * w_scale)
out_dtype = dtype.qint8(outp_scale)
def run(
N,
IC,
OC,
IH,
IW,
KH,
KW,
PH,
PW,
SH,
SW,
has_bias=True,
nonlinear_mode="IDENTITY",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
b_v = np.random.normal(size=(1, OC, 1, 1))
inp_scale = dtype.get_scale(inp_dtype)
w_scale = dtype.get_scale(w_dtype)
b_scale = dtype.get_scale(b_dtype)
inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
inp_int8 = mge.tensor(inpv, dtype=inp_dtype)
w_int8 = mge.Parameter(wv, dtype=w_dtype)
b_int32 = mge.Parameter(bv, dtype=b_dtype)
inp_fp32 = inp_int8.astype("float32")
w_fp32 = w_int8.astype("float32")
b_fp32 = b_int32.astype("float32")
def convert_to_nchw4(var):
var = F.reshape(
var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
)
var = F.transpose(var, (0, 1, 3, 4, 2))
return var
def run_conv2d(inp, w, b):
O = F.conv2d(
inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
)
if nonlinear_mode == "RELU":
return F.relu(O)
else:
return O
def run_conv_bias(inp, w, b, format="NCHW"):
b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
if format == "NCHW4":
inp = convert_to_nchw4(inp)
w = convert_to_nchw4(w)
b = convert_to_nchw4(b)
return F.quantized.conv_bias_activation(
inp,
w,
b,
stride=(SH, SW),
padding=(PH, PW),
dtype=out_dtype,
nonlinear_mode=nonlinear_mode,
)
format = "NCHW4" if mge.is_cuda_available() else "NCHW"
expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
expected = expected.astype(out_dtype).astype("float32")
result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
"float32"
)
if format == "NCHW4":
result = F.transpose(result, (0, 1, 4, 2, 3))
expected = F.flatten(expected)
result = F.flatten(result)
np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
......@@ -88,12 +88,14 @@ def test_propagate_qconfig():
def init_qat_net():
net = QATNet()
propagate_qconfig(net, min_max_fakequant_qconfig)
min_val = np.random.randint(-127, 0, size=(2,))
max_val = np.random.randint(1, 127, size=(2,))
net.linear.weight_observer.min_val.set_value(min_val[0])
net.linear.weight_observer.max_val.set_value(max_val[0])
net.linear.act_observer.min_val.set_value(min_val[1])
net.linear.act_observer.max_val.set_value(max_val[1])
min_val = np.random.randint(-127, 0, size=(3,))
max_val = np.random.randint(1, 127, size=(3,))
net.quant.act_observer.min_val.set_value(min_val[0])
net.quant.act_observer.max_val.set_value(max_val[0])
net.linear.weight_observer.min_val.set_value(min_val[1])
net.linear.weight_observer.max_val.set_value(max_val[1])
net.linear.act_observer.min_val.set_value(min_val[2])
net.linear.act_observer.max_val.set_value(max_val[2])
return net
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册