# test_module.py
from functools import partial

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as Float
import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine import Parameter, Tensor
from megengine.core.tensor import dtype
from megengine.quantization import (
    FakeQuantize,
    MinMaxObserver,
    QConfig,
    QuantMode,
    create_qparams,
)
from megengine.quantization.quantize import (
    disable_fake_quant,
    disable_observer,
    propagate_qconfig,
)


# QConfig applied to every QAT module under test: min/max observers paired with
# fake-quantizers, using "qint8_narrow" for weights and "qint8" for activations.
min_max_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    act_observer=partial(MinMaxObserver, dtype="qint8"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

def gen_inp_scale():
    """Return a random ``np.float32`` input scale of the form ``1 + U[0, 1)``."""
    unit_sample = np.random.rand()
    return np.float32(1 + unit_sample)


# Randomly drawn observer bounds (two values each). `init_qat_net` loads index 0
# into the weight observer and index 1 into the activation observer.
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
# Reference scales computed from the larger bound magnitude. The divisors 254 and
# 255 equal qmax - qmin of `fake_quant_weight` (-127..127) and `fake_quant_act`
# (-128..127) respectively.
# NOTE(review): presumably mirrors the scale MinMaxObserver derives for the
# "qint8_narrow" / "qint8" dtypes — confirm against the observer implementation.
weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)


def quant(x, scale):
    """Cast ``x`` to the ``qint8`` dtype built from ``scale`` and return the result."""
    return x.astype(dtype.qint8(scale))


def fake_quant(x, scale, qmin, qmax):
    """Simulate quantization of ``x`` while staying in floating point.

    Args:
        x: value to fake-quantize.
        scale: quantization step; ``x`` is divided by it before rounding and
            multiplied by it afterwards.
        qmin: lower clipping bound applied to the rounded value.
        qmax: upper clipping bound applied to the rounded value.

    Returns:
        ``clip(round(x / scale), qmin, qmax) * scale``.
    """
    x = x / scale
    x = F.round(x)
    x = F.clip(x, qmin, qmax)
    x = x * scale
    return x


# `fake_quant` with the integer clipping range pre-bound for each tensor role.
# Activations clip to -128..127.
fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
# Weights clip to the symmetric range -127..127.
fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
# Bias clips to the signed 32-bit range.
fake_quant_bias = partial(fake_quant, qmin=-(1 << 31), qmax=(1 << 31) - 1)


def init_qat_net(net):
    """Overwrite the observer bounds of ``net`` with the module-level test values.

    The weight observer (when ``net.with_weight`` is truthy) receives
    ``min_val[0]`` / ``max_val[0]``; the activation observer (when
    ``net.with_act`` is truthy) receives ``min_val[1]`` / ``max_val[1]``.
    The bounds are written in place via ``[...]`` assignment.
    """
    if net.with_weight:
        net.weight_observer.min_val[...] = Tensor(min_val[0])
        net.weight_observer.max_val[...] = Tensor(max_val[0])
    if net.with_act:
        net.act_observer.min_val[...] = Tensor(min_val[1])
        net.act_observer.max_val[...] = Tensor(max_val[1])


def test_quant_stub():
    """Compare the float, QAT and quantized variants of ``QuantStub``.

    Checks that:
    * a QAT module built from the float module, with observer and fake-quant
      disabled, matches the float output;
    * the configured QAT module matches the fake-quantized float output;
    * the quantized module output, rescaled by ``act_scale``, matches it too.
    """
    normal_net = Float.QuantStub()
    normal_net.eval()

    qat_from_float = QAT.QuantStub.from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)

    qat_net = QAT.QuantStub()
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.QuantStub.from_qat_module(qat_net)
    q_net.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    # The quantized output is multiplied by act_scale before the comparison.
    q = q_net(x).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal)
    np.testing.assert_allclose(q, fake_quant_normal.numpy())


def test_dequant_stub():
    """Compare the float, QAT and quantized variants of ``DequantStub``.

    The input is fake-quantized with a random scale, and that scale is recorded
    on the tensor's ``qparams``. All three variants are compared against the
    float module output; no activation fake-quant is applied to the reference.
    """
    normal_net = Float.DequantStub()
    normal_net.eval()

    qat_from_float = QAT.DequantStub.from_float_module(normal_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)

    qat_net = QAT.DequantStub()
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.DequantStub.from_qat_module(qat_net)
    q_net.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.scale = inp_scale

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = normal_net(x)
    qat = qat_net(x)
    # The quantized module is fed the qint8 cast of the same input.
    q = q_net(quant(x, inp_scale)).numpy()
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal)
    np.testing.assert_allclose(q, fake_quant_normal.numpy())


@pytest.mark.parametrize("kind", ["cos", "relu", "add", "mul", "fuse_add_relu"])
def test_elemwise(kind):
    """Compare the float, QAT and quantized variants of ``Elemwise`` for ``kind``.

    "add", "mul" and "fuse_add_relu" are called with two inputs; the remaining
    kinds are called with one. Each input is fake-quantized with its own
    random scale, which is also recorded on the tensor's ``qparams``.
    """
    normal_net = Float.Elemwise(kind)
    normal_net.eval()

    qat_from_float = QAT.Elemwise.from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)

    qat_net = QAT.Elemwise(kind)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.Elemwise.from_qat_module(qat_net)
    q_net.eval()

    # Scales come from the shared helper, like the other tests in this file.
    x1_scale = gen_inp_scale()
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant_act(x1, x1_scale)
    x1.qparams.scale = x1_scale

    x2_scale = gen_inp_scale()
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant_act(x2, x2_scale)
    x2.qparams.scale = x2_scale

    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)

    # test correctness of `Float`, `QAT` and `Quantized`
    if kind in ("add", "mul", "fuse_add_relu"):
        normal = normal_net(x1, x2)
        qat_without_fakequant = qat_from_float(x1, x2)
        fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
        qat = qat_net(x1, x2)
        q = q_net(x1_int8, x2_int8).numpy() * act_scale
    else:
        normal = normal_net(x1)
        qat_without_fakequant = qat_from_float(x1)
        fake_quant_normal = fake_quant_act(normal_net(x1), act_scale)
        qat = qat_net(x1)
        q = q_net(x1_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal)
    np.testing.assert_allclose(q, fake_quant_normal.numpy())


def test_linear():
    """Compare the float, QAT and quantized variants of ``Linear(3, 3, bias=True)``.

    The float module is loaded with already fake-quantized weight and bias, while
    the QAT module receives the raw values. The bias is fake-quantized with
    scale ``inp_scale * weight_scale``.
    """
    normal_net = Float.Linear(3, 3, bias=True)
    normal_net.eval()

    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")
    normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
    normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
    qat_net.weight[...] = Parameter(weight)
    qat_net.bias[...] = Parameter(bias)

    # Built after the parameters are assigned, so it sees the fake-quantized values.
    qat_from_float = QAT.Linear.from_float_module(normal_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)

    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    q = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal)
    np.testing.assert_allclose(qat, fake_quant_normal.numpy())
    np.testing.assert_allclose(q, fake_quant_normal.numpy())


@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
    """Compare the float, QAT and quantized variants of a convolution module.

    ``module`` names the class looked up on ``Float``, ``QAT`` and ``Q``. For the
    fused "ConvBn2d" / "ConvBnRelu2d" variants the parameters live on the
    ``.conv`` sub-module; for "Conv2d" they live on the module itself.
    The QAT and quantized comparisons use ``atol=act_scale``.
    """
    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
    normal_net.eval()

    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
    if module in ("ConvBn2d", "ConvBnRelu2d"):
        normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.conv.weight[...] = Parameter(weight)
        qat_net.conv.bias[...] = Parameter(bias)
    else:
        normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
        normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
        qat_net.weight[...] = Parameter(weight)
        qat_net.bias[...] = Parameter(bias)

    # Built after the parameters are assigned, so it sees the fake-quantized values.
    qat_from_float = getattr(QAT, module).from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)

    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
    qat = qat_net(x)
    q = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
    np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
    np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)


def test_concat():
    """Compare the float, QAT and quantized variants of ``Concat``.

    Three inputs are generated, each fake-quantized with its own random scale;
    the scale is recorded on the tensor's ``qparams`` and a qint8 cast of every
    input is kept for the quantized module.
    """
    float_net = Float.Concat()
    float_net.eval()

    qat_net = QAT.Concat()
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    inps, inps_int8 = [], []
    for _ in range(3):
        # Draw the scale first, then the data, to keep the RNG order stable.
        scale = gen_inp_scale()
        sample = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
        sample = fake_quant_act(sample, scale)
        sample.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", scale))
        inps.append(sample)
        inps_int8.append(quant(sample, scale))

    qat_from_float = QAT.Concat.from_float_module(float_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)

    q_net = Q.Concat.from_qat_module(qat_net)
    q_net.eval()

    float_out = float_net(inps)
    plain_qat_out = qat_from_float(inps)
    reference = fake_quant_act(float_net(inps), act_scale)
    qat_out = qat_net(inps)
    quantized_out = q_net(inps_int8).numpy() * act_scale

    np.testing.assert_allclose(plain_qat_out, float_out)
    reference_np = reference.numpy()
    np.testing.assert_allclose(qat_out, reference_np)
    np.testing.assert_allclose(quantized_out, reference_np)