# test_qat.py: checks that QAT-converted modules (fake quantization disabled) match their float counterparts.
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import (
    Conv2d,
    ConvBn2d,
    ConvRelu2d,
    DequantStub,
    Module,
    QuantStub,
)
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose


def test_qat_convbn2d():
    """Check that the QAT conversion of ``ConvBn2d`` matches the float module.

    For every combination of ``groups`` and ``bias``, a ``ConvBn2d`` module and
    its QAT copy (with fake quantization disabled) must agree on:

    * the training-mode outputs,
    * the batch-norm ``running_mean`` / ``running_var`` after that training pass,
    * the eval-mode outputs.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3
    for groups, bias in product([1, 4], [True, False]):
        module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        module.train()
        # inplace=False so that ``module`` remains available as the float
        # reference to compare the QAT copy against.
        qat_module = quantize_qat(module, inplace=False)
        # With fake quantization off, the QAT module is expected to reproduce
        # the float computation, which is what allows the tight tolerances below.
        disable_fake_quant(qat_module)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        normal_outputs = module(inputs)
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
        # The training-mode forward passes above are expected to update the BN
        # running statistics; both modules must have accumulated the same values.
        assertTensorClose(
            module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            max_err=5e-8,
        )
        assertTensorClose(
            module.bn.running_var.numpy(),
            qat_module.bn.running_var.numpy(),
            max_err=5e-7,
        )
        # Eval mode uses the accumulated running statistics, so outputs must
        # still agree after switching both modules.
        module.eval()
        normal_outputs = module(inputs)
        qat_module.eval()
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)


def test_qat_conv():
    """Check that QAT conversion of ``Conv2d`` / ``ConvRelu2d`` matches the float net.

    A small ``quant -> Conv2d -> ConvRelu2d -> dequant`` network is built for
    every combination of ``groups`` and ``bias``. Its QAT copy, with fake
    quantization disabled, must produce the same outputs as the original in
    both training and eval mode.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    class TestNet(Module):
        """quant -> conv (in -> out channels) -> conv_relu (out -> in) -> dequant."""

        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels, out_channels, kernel_size, groups=groups, bias=bias
            )
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv_relu(self.conv(self.quant(inp))))

    def check_same_outputs(float_net, converted_net, batch):
        # Run the float reference first, then the QAT copy, and compare.
        expected = float_net(batch)
        actual = converted_net(batch)
        assertTensorClose(expected.numpy(), actual.numpy())

    # One shared input batch, drawn before any network is constructed.
    batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product([1, 4], [True, False]):
        float_net = TestNet(groups, bias)
        float_net.train()
        converted_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(converted_net)
        check_same_outputs(float_net, converted_net, batch)

        # The QAT net is an independent copy, so switching both to eval mode
        # before running either forward pass is equivalent.
        float_net.eval()
        converted_net.eval()
        check_same_outputs(float_net, converted_net, batch)