From bc9f9cd4fc08552133bc9768088842f7c8b86d9c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 18 May 2023 15:36:53 +0800 Subject: [PATCH] feat(imperative/module): add linear fuse bn and relu support GitOrigin-RevId: c342f687974f9b47304c26ac39a59b74caaf7f70 --- .../python/megengine/module/__init__.py | 3 +- imperative/python/megengine/module/linear.py | 15 +- .../python/megengine/module/linear_bn.py | 50 ++++++ .../python/megengine/module/qat/__init__.py | 3 +- .../python/megengine/module/qat/linear.py | 18 +- .../python/megengine/module/qat/linear_bn.py | 164 ++++++++++++++++++ .../megengine/module/quantized/__init__.py | 3 +- .../megengine/module/quantized/linear.py | 23 ++- .../megengine/module/quantized/linear_bn.py | 40 +++++ .../python/megengine/utils/bn_fusion.py | 97 ++++++++++- .../python/test/unit/module/test_qat.py | 128 ++++++++++++++ .../test/unit/quantization/test_quantize.py | 48 ++++- 12 files changed, 581 insertions(+), 11 deletions(-) create mode 100644 imperative/python/megengine/module/linear_bn.py create mode 100644 imperative/python/megengine/module/qat/linear_bn.py create mode 100644 imperative/python/megengine/module/quantized/linear_bn.py diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 89e1cbd02..60a21c2a2 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -24,7 +24,8 @@ from .dropout import Dropout from .elemwise import Elemwise from .embedding import Embedding from .identity import Identity -from .linear import Linear +from .linear import Linear, LinearRelu +from .linear_bn import LinearBn1d, LinearBnRelu1d from .lrn import LocalResponseNorm from .module import Module from .multiheadattn import MultiHeadAttention diff --git a/imperative/python/megengine/module/linear.py b/imperative/python/megengine/module/linear.py index a1f404c3e..426f4b247 100644 --- a/imperative/python/megengine/module/linear.py +++ b/imperative/python/megengine/module/linear.py @@ -1,6 +1,6 @@ import numpy as np -from ..functional.nn import linear +from ..functional.nn import linear, relu from ..tensor import Parameter from . import init from .module import Module @@ -62,13 +62,22 @@ class Linear(Module): if self.bias is not None: init.zeros_(self.bias) - def _calc_linear(self, x, weight, bias): + def calc_linear(self, x, weight, bias): return linear(x, weight, bias, compute_mode=self.compute_mode) def forward(self, x): - return self._calc_linear(x, self.weight, self.bias) + return self.calc_linear(x, self.weight, self.bias) def _module_info_string(self) -> str: return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features, self.bias is not None ) + + +class LinearRelu(Linear): + r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :func:`~.relu`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearRelu` using :func:`~.quantize.quantize_qat`. 
+ """ + + def forward(self, inp): + return relu(self.calc_linear(inp, self.weight, self.bias)) diff --git a/imperative/python/megengine/module/linear_bn.py b/imperative/python/megengine/module/linear_bn.py new file mode 100644 index 000000000..f89cd69f3 --- /dev/null +++ b/imperative/python/megengine/module/linear_bn.py @@ -0,0 +1,50 @@ +import numpy as np + +from ..functional import relu +from ..tensor import Parameter +from .batchnorm import BatchNorm1d +from .linear import Linear +from .module import Module + + +class _LinearBnActivation1d(Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + compute_mode: str = "default", + eps=1e-5, + momentum=0.9, + affine=True, + track_running_stats=True, + **kwargs + ): + super().__init__(**kwargs) + self.out_features = out_features + self.in_features = in_features + self.bias = None + if bias: + b_shape = (out_features,) + self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) + self.linear = Linear(in_features, out_features, bias, compute_mode, **kwargs,) + self.bn = BatchNorm1d(out_features, eps, momentum, affine, track_running_stats) + + +class LinearBn1d(_LinearBnActivation1d): + r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBn1d` using + :func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return self.bn(self.linear(inp)) + + +class LinearBnRelu1d(_LinearBnActivation1d): + r"""A fused :class:`~.Module` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBnRelu1d` using :func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return relu(self.bn(self.linear(inp))) diff --git a/imperative/python/megengine/module/qat/__init__.py b/imperative/python/megengine/module/qat/__init__.py index 2a95dabf6..bfae36ae1 100644 --- a/imperative/python/megengine/module/qat/__init__.py +++ b/imperative/python/megengine/module/qat/__init__.py @@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d from .conv_bn import ConvBn2d, ConvBnRelu2d from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d from .elemwise import Elemwise -from .linear import Linear +from .linear import Linear, LinearRelu +from .linear_bn import LinearBn1d, LinearBnRelu1d from .module import QATModule from .quant_dequant import DequantStub, QuantStub diff --git a/imperative/python/megengine/module/qat/linear.py b/imperative/python/megengine/module/qat/linear.py index 1e879a282..c32a2a6a6 100644 --- a/imperative/python/megengine/module/qat/linear.py +++ b/imperative/python/megengine/module/qat/linear.py @@ -1,3 +1,4 @@ +from ... import functional as F from .. 
import linear as Float from .module import QATModule @@ -13,10 +14,16 @@ class Linear(Float.Linear, QATModule): Default: True """ + def calc_linear_qat(self, inp): + w_qat = self.apply_quant_weight(self.weight) + b_qat = self.apply_quant_bias(self.bias, inp, w_qat) + linear = self.calc_linear(inp, w_qat, b_qat) + return linear + def forward(self, inp): w_qat = self.apply_quant_weight(self.weight) b_qat = self.apply_quant_bias(self.bias, inp, w_qat) - return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) + return self.apply_quant_activation(self.calc_linear(inp, w_qat, b_qat)) @classmethod def from_float_module(cls, float_module: Float.Linear): @@ -30,3 +37,12 @@ class Linear(Float.Linear, QATModule): qmod.weight = float_module.weight qmod.bias = float_module.bias return qmod + + +class LinearRelu(Linear): + r"""A :class:`~.QATModule` include :class:`~.module.Linear` and :func:`~.relu` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(F.relu(self.calc_linear_qat(inp))) diff --git a/imperative/python/megengine/module/qat/linear_bn.py b/imperative/python/megengine/module/qat/linear_bn.py new file mode 100644 index 000000000..7ee49b534 --- /dev/null +++ b/imperative/python/megengine/module/qat/linear_bn.py @@ -0,0 +1,164 @@ +from ...functional import linear, ones, relu, sqrt, sum, zeros +from .. import linear_bn as Float +from .module import QATModule + + +class _LinearBnActivation1d(Float._LinearBnActivation1d, QATModule): + def get_batch_mean_var(self, inp): + def _sum_channel(inp, axis=0, keepdims=True): + if isinstance(axis, int): + out = sum(inp, axis=axis, keepdims=keepdims) + elif isinstance(axis, tuple): + for idx, elem in enumerate(axis): + out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) + return out + + sum1 = _sum_channel(inp, (0, 2, 3)) + sum2 = _sum_channel(inp ** 2, (0, 2, 3)) + reduce_size = inp.size / inp.shape[1] + batch_mean = sum1 / reduce_size + batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size + return batch_mean, batch_var + + def fold_weight_bias(self, bn_mean, bn_var): + weight_shape = [1] * len(self.linear.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.linear.weight.shape) + bias_shape[1] = -1 + + # get fold bn linear param + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features,), dtype="float32") + gamma = gamma.reshape(-1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features,), dtype="float32") + beta = beta.reshape(-1) + + if bn_mean is None: + bn_mean = zeros((self.bn.num_features,), dtype="float32") + bn_mean = bn_mean.reshape(-1) + if bn_var is None: + bn_var = ones((self.bn.num_features,), dtype="float32") + bn_var = bn_var.reshape(-1) + + linear_bias = self.linear.bias + if linear_bias is None: + linear_bias = zeros(beta.shape(), dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + scale_factor = gamma * bn_istd + w_fold = self.linear.weight * scale_factor.reshape(weight_shape) + w_fold = self.apply_quant_weight(w_fold) + b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd + return w_fold, b_fold + + def update_running_mean_and_running_var( + self, bn_mean, bn_var, num_elements_per_channel + ): + # update running mean and running var. 
no grad, use unbiased bn var + bn_mean = bn_mean.detach() + bn_var = ( + bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) + ) + exponential_average_factor = 1 - self.bn.momentum + self.bn.running_mean *= self.bn.momentum + self.bn.running_mean += exponential_average_factor * bn_mean + self.bn.running_var *= self.bn.momentum + self.bn.running_var += exponential_average_factor * bn_var + + def calc_linear_bn_qat(self, inp, approx=True): + if self.training and not approx: + linear = self.linear(inp) + bn_mean, bn_var = self.get_batch_mean_var(linear) + num_elements_per_channel = linear.size / linear.shape[1] + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) + else: + bn_mean, bn_var = self.bn.running_mean, self.bn.running_var + + bn_mean, bn_var = ( + self.bn.running_mean.reshape(-1), + self.bn.running_var.reshape(-1), + ) + + weight_shape = [1] * len(self.linear.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.linear.weight.shape) + bias_shape[1] = -1 + + # get gamma and beta in BatchNorm + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features,), dtype="float32") + gamma = gamma.reshape(-1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features,), dtype="float32") + beta = beta.reshape(-1) + + # linear_bias + linear_bias = self.linear.bias + if linear_bias is None: + linear_bias = zeros(beta.shape, dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + scale_factor = gamma * bn_istd + + w_fold = self.linear.weight * scale_factor.reshape(weight_shape) + b_fold = None + + if not (self.training and approx): + b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd + + w_qat = self.apply_quant_weight(w_fold) + b_qat = self.apply_quant_bias(b_fold, inp, w_qat) + linear = self.linear.calc_linear(inp, w_qat, b_qat) + if not (self.training and approx): + return linear + + # rescale linear to get original linear output + orig_linear = linear / scale_factor.reshape(bias_shape) + if self.linear.bias is not None: + orig_linear = orig_linear + self.linear.bias.reshape(bias_shape) + # calculate batch norm + linear = self.bn(orig_linear) + return linear + + @classmethod + def from_float_module(cls, float_module: Float._LinearBnActivation1d): + qat_module = cls( + float_module.linear.in_features, + float_module.linear.out_features, + float_module.linear.bias is not None, + float_module.linear.compute_mode, + float_module.bn.eps, + float_module.bn.momentum, + float_module.bn.affine, + float_module.bn.track_running_stats, + name=float_module.name, + ) + qat_module.linear.weight = float_module.linear.weight + qat_module.linear.bias = float_module.linear.bias + qat_module.bn = float_module.bn + return qat_module + + +class LinearBn1d(_LinearBnActivation1d): + r"""A fused :class:`~.QATModule` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(self.calc_linear_bn_qat(inp)) + + +class LinearBnRelu1d(_LinearBnActivation1d): + r"""A fused :class:`~.QATModule` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. 
+ """ + + def forward(self, inp): + return self.apply_quant_activation(relu(self.calc_linear_bn_qat(inp))) diff --git a/imperative/python/megengine/module/quantized/__init__.py b/imperative/python/megengine/module/quantized/__init__.py index 11ff807ec..349b39664 100644 --- a/imperative/python/megengine/module/quantized/__init__.py +++ b/imperative/python/megengine/module/quantized/__init__.py @@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d from .conv_bn import ConvBn2d, ConvBnRelu2d from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d from .elemwise import Elemwise -from .linear import Linear +from .linear import Linear, LinearRelu +from .linear_bn import LinearBn1d, LinearBnRelu1d from .module import QuantizedModule from .quant_dequant import DequantStub, QuantStub diff --git a/imperative/python/megengine/module/quantized/linear.py b/imperative/python/megengine/module/quantized/linear.py index 031b5a50b..4f3c65bc6 100644 --- a/imperative/python/megengine/module/quantized/linear.py +++ b/imperative/python/megengine/module/quantized/linear.py @@ -1,6 +1,7 @@ import numpy as np from ... import functional as F +from ... import module as Float from ...core.tensor import dtype from ...tensor import Parameter from ..qat import linear as QAT @@ -16,20 +17,30 @@ class Linear(QuantizedModule): self.bias = None self.output_dtype = dtype - def forward(self, inp): + def calc_linear_quantized(self, inp, nonlinear_mode="identity"): if self.training: raise ValueError("quantized module only support inference.") + + assert nonlinear_mode in ["identity", "relu"] + inp_scale = dtype.get_scale(inp.dtype) w_scale = dtype.get_scale(self.weight.dtype) bias_dtype = dtype.qint32(inp_scale * w_scale) - ret = F.nn.linear( + ret = F.linear( inp, self.weight, None if self.bias is None else self.bias.astype(bias_dtype), ) ret = ret if self.output_dtype is None else ret.astype(self.output_dtype) + + if nonlinear_mode == "relu": + ret = F.relu(ret) + return ret + def forward(self, inp): + return self.calc_linear_quantized(inp) + @classmethod def from_qat_module(cls, qat_module: QAT.Linear): r""" @@ -38,8 +49,16 @@ class Linear(QuantizedModule): """ output_dtype = qat_module.get_activation_dtype() qmod = cls(dtype=output_dtype, name=qat_module.name) + qmod.name = qat_module.name weight = qat_module.weight.astype(qat_module.get_weight_dtype()) qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) return qmod + + +class LinearRelu(Linear): + r"""Quantized version of :class:`~.qat.LinearRelu`.""" + + def forward(self, inp): + return self.calc_linear_quantized(inp, nonlinear_mode="relu") diff --git a/imperative/python/megengine/module/quantized/linear_bn.py b/imperative/python/megengine/module/quantized/linear_bn.py new file mode 100644 index 000000000..49d732dbc --- /dev/null +++ b/imperative/python/megengine/module/quantized/linear_bn.py @@ -0,0 +1,40 @@ +from ...tensor import Parameter +from ..qat import linear_bn as QAT +from .linear import Linear + + +class _LinearBnActivation1d(Linear): + r"""Applies a Linear over a quantized input tensor, used for inference only. + """ + + @classmethod + def from_qat_module(cls, qat_module: QAT._LinearBnActivation1d): + r""" + Return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. 
+ """ + output_dtype = qat_module.get_activation_dtype() + qlinear = cls(dtype=output_dtype, name=qat_module.name,) + w_fold, b_fold = qat_module.fold_weight_bias( + qat_module.bn.running_mean, qat_module.bn.running_var + ) + weight = w_fold.astype(qat_module.get_weight_dtype()) + qlinear.weight = Parameter(weight.numpy(), name=qat_module.linear.weight.name) + qlinear.bias = Parameter(b_fold.numpy()) + if qat_module.linear.bias is not None: + qlinear.bias.name = qat_module.linear.bias.name + return qlinear + + +class LinearBn1d(_LinearBnActivation1d): + r"""Quantized version of :class:`~.qat.LinearBn1d`.""" + + def forward(self, inp): + return self.calc_linear_quantized(inp, nonlinear_mode="identity") + + +class LinearBnRelu1d(_LinearBnActivation1d): + r"""Quantized version of :class:`~.qat.LinearBnRelu1d`.""" + + def forward(self, inp): + return self.calc_linear_quantized(inp, nonlinear_mode="relu") diff --git a/imperative/python/megengine/utils/bn_fusion.py b/imperative/python/megengine/utils/bn_fusion.py index 01e435f42..4814e63c7 100644 --- a/imperative/python/megengine/utils/bn_fusion.py +++ b/imperative/python/megengine/utils/bn_fusion.py @@ -2,6 +2,7 @@ from copy import deepcopy from ..functional import ones, sqrt, zeros from ..module import ( + BatchNorm1d, BatchNorm2d, Conv2d, ConvBn2d, @@ -11,6 +12,10 @@ from ..module import ( ConvTransposeBn2d, ConvTransposeBnRelu2d, ConvTransposeRelu2d, + Linear, + LinearBn1d, + LinearBnRelu1d, + LinearRelu, ReLU, ) from ..tensor import Parameter @@ -26,10 +31,15 @@ _MAP_TO_FUSED_MODULE = { (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d, (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d, (ConvTranspose2d, ReLU): ConvTransposeRelu2d, + (Linear, BatchNorm1d, ReLU, False): LinearRelu, + (Linear, BatchNorm1d, ReLU, True): LinearBnRelu1d, + (Linear, BatchNorm1d, False): Linear, + (Linear, BatchNorm1d, True): LinearBn1d, + (Linear, ReLU): LinearRelu, } -def fold_weight_bias( +def _fold_conv_bn_weight_bias( weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False ): shape = (-1, 1, 1, 1) @@ -76,6 +86,57 @@ def fold_weight_bias( return w_fold, b_fold +def _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): + bn_mean, bn_var = bn_mean.reshape(-1), bn_var.reshape(-1) + weight_shape = [1] * len(weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(weight.shape) + bias_shape[1] = -1 + + out_features = weight.shape[0] + if gamma is None: + gamma = ones((out_features,), dtype="float32") + else: + gamma = gamma.reshape(-1) + if beta is None: + beta = zeros((out_features,), dtype="float32") + else: + beta = beta.reshape(-1) + + if bn_mean is None: + bn_mean = zeros((out_features,), dtype="float32") + else: + bn_mean = bn_mean.reshape(-1) + if bn_var is None: + bn_var = ones((out_features,), dtype="float32") + else: + bn_var = bn_var.reshape(-1) + + if bias is None: + bias = zeros((beta.shape), dtype="float32") + else: + bias = bias.reshape(-1) + + bn_istd = 1.0 / sqrt(bn_var + eps) + scale_factor = gamma * bn_istd + + w_fold = weight * scale_factor.reshape(*weight_shape) + b_fold = beta + gamma * (bias - bn_mean) * bn_istd + + return w_fold, b_fold + + +def fold_weight_bias( + weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False +): + if weight.ndim != 2: + return _fold_conv_bn_weight_bias( + weight, bias, gamma, beta, bn_mean, bn_var, eps, transpose + ) + + return _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps) + + def fuse_conv_bn_relu_module(conv: 
Conv2d, bn: BatchNorm2d, relu: ReLU): module_key = tuple([type(m) for m in [conv, bn, relu] if m]) if bn: @@ -137,3 +198,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): module.bn = deepcopy(bn) new_conv.training = conv.training return module + + +def fuse_linear_bn_relu_module(linear: Linear, bn: BatchNorm1d, relu: ReLU): + module_key = tuple([type(m) for m in [linear, bn, relu] if m]) + if bn: + assert ( + linear.training == bn.training + ), "Linear and BN both must be in the same mode (train or eval)." + assert ( + bn.num_features == linear.out_features + ), "Output channel of Linear must match num_features of BatchNorm1d" + module_key = module_key + (linear.training,) + module = _MAP_TO_FUSED_MODULE[module_key]( + in_features=linear.in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + compute_mode=linear.compute_mode, + name=linear.name, + ) + + new_linear = module if bn is None or not linear.training else module.linear + + weight, bias = linear.weight, linear.bias + if not linear.training and bn is not None: + weight, bias = fold_weight_bias( + weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, + ) + new_linear.weight = Parameter(weight) + if bias is not None: + new_linear.bias = Parameter(bias) + if bn is not None and linear.training: + module.bn = deepcopy(bn) + new_linear.training = linear.training + return module diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index 7b99fe0fe..1ab3fc0f7 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -19,6 +19,10 @@ from megengine.module import ( ConvTransposeBn2d, ConvTransposeRelu2d, DequantStub, + Linear, + LinearBn1d, + LinearBnRelu1d, + LinearRelu, Module, QuantStub, ) @@ -330,3 +334,127 @@ def test_qat_conv_transpose2d(): np.testing.assert_allclose( normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 ) + + +def test_qat_linearbn1d(): + in_features = 10 + out_features = 5 + + class TestNet(Module): + def __init__(self, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.linear_bn = LinearBn1d(in_features, out_features, bias=bias,) + + def forward(self, inp): + out = self.quant(inp) + out = self.linear_bn(out) + out = self.dequant(out) + return out + + inputs = tensor(np.random.randn(4, in_features).astype(np.float32)) + for bias in [True, False]: + net = TestNet(bias) + net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) + np.testing.assert_allclose( + net.linear_bn.bn.running_mean.numpy(), + qat_net.linear_bn.bn.running_mean.numpy(), + atol=5e-8, + ) + np.testing.assert_allclose( + net.linear_bn.bn.running_var.numpy(), + qat_net.linear_bn.bn.running_var.numpy(), + atol=5e-7, + ) + + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) + + +def test_qat_linear_relu(): + in_features = 10 + out_features = 5 + + class TestNet(Module): + def __init__(self, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.linear_relu = LinearRelu(in_features, out_features, bias=bias,) + + def forward(self, inp): + out = self.quant(inp) + out = 
self.linear_relu(out) + out = self.dequant(out) + return out + + inputs = tensor(np.random.randn(4, in_features).astype(np.float32)) + for bias in [True, False]: + net = TestNet(bias) + net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) + + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) + + +def test_qat_linear_bn_relu(): + in_features = 10 + out_features = 5 + + class TestNet(Module): + def __init__(self, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.linear_bn_relu = LinearBnRelu1d(in_features, out_features, bias=bias,) + + def forward(self, inp): + out = self.quant(inp) + out = self.linear_bn_relu(out) + out = self.dequant(out) + return out + + inputs = tensor(np.random.randn(4, in_features).astype(np.float32)) + for bias in [True, False]: + net = TestNet(bias) + net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) + + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6, + ) diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py index 82661bc2d..3412a5ffd 100644 --- a/imperative/python/test/unit/quantization/test_quantize.py +++ b/imperative/python/test/unit/quantization/test_quantize.py @@ -5,11 +5,14 @@ from megengine import Parameter, Tensor from megengine import module as Float from megengine.functional import ones, zeros from megengine.module import ( + BatchNorm1d, BatchNorm2d, Conv2d, ConvBn2d, ConvTranspose2d, ConvTransposeBn2d, + Linear, + LinearBn1d, ReLU, ) from megengine.module import qat as QAT @@ -33,7 +36,10 @@ from megengine.quantization.quantize import ( quantize_qat, reset_qconfig, ) -from megengine.utils.bn_fusion import fuse_conv_bn_relu_module +from megengine.utils.bn_fusion import ( + fuse_conv_bn_relu_module, + fuse_linear_bn_relu_module, +) class FloatNet(Float.Module): @@ -383,3 +389,43 @@ def test_ConvTransposeBn2d_fold_weight_bias(): np.testing.assert_allclose( expected_result.numpy(), actual_result.numpy(), atol=1e-4 ) + + +def test_LinearBn1d_fold_weight_bias(): + in_features = 10 + out_features = 5 + + linear = Linear(in_features, out_features) + bn = BatchNorm1d(out_features) + relu = ReLU() + + fused_linear = fuse_linear_bn_relu_module(linear, bn, relu) + bn.eval() + fused_linear.eval() + inputs = Tensor(np.random.randn(4, in_features).astype(np.float32)) + expected_result = relu(bn(linear(inputs))) + actual_result = fused_linear(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + linear.eval() + bn.eval() + relu.eval() + fused_linear = fuse_linear_bn_relu_module(linear, bn, relu) + fused_linear.eval() + expected_result = relu(linear(inputs)) + actual_result = fused_linear(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + linear.train() + bn.train() + fused_linear = 
fuse_linear_bn_relu_module(linear, bn, None) + fused_linear.train() + expected_result = bn(linear(inputs)) + actual_result = fused_linear(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) -- GitLab
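
Note for reviewers (not part of the patch): the core of `_fold_linear_bn_weight_bias`, `fold_weight_bias`, and the eval-mode path of `calc_linear_bn_qat` is the identity that an eval-mode Linear -> BatchNorm1d -> ReLU chain collapses into a single affine transform followed by ReLU, with the per-output-channel scale `gamma / sqrt(running_var + eps)` folded into the weight rows and the shift folded into the bias. The sketch below checks that identity with plain NumPy only; it is an illustrative, assumption-laden example rather than code from this patch, and names such as `w_fold`, `b_fold`, and `scale` merely mirror the patch for readability.

```python
# Minimal NumPy sketch of the Linear+BN folding identity used in this patch.
# Assumption: BatchNorm1d runs in eval mode with fixed running statistics.
import numpy as np

rng = np.random.default_rng(0)
N, in_features, out_features, eps = 4, 10, 5, 1e-5

x = rng.standard_normal((N, in_features)).astype(np.float32)
w = rng.standard_normal((out_features, in_features)).astype(np.float32)
b = rng.standard_normal(out_features).astype(np.float32)

# frozen BatchNorm1d statistics and affine parameters (hypothetical values)
running_mean = rng.standard_normal(out_features).astype(np.float32)
running_var = (rng.random(out_features) + 0.5).astype(np.float32)
gamma = rng.standard_normal(out_features).astype(np.float32)
beta = rng.standard_normal(out_features).astype(np.float32)

def relu(t):
    return np.maximum(t, 0)

# reference path: Linear -> BatchNorm1d (eval) -> ReLU
y_linear = x @ w.T + b
y_ref = relu(gamma * (y_linear - running_mean) / np.sqrt(running_var + eps) + beta)

# folded path: scale each output row of the weight, shift the bias once
bn_istd = 1.0 / np.sqrt(running_var + eps)
scale = gamma * bn_istd                # per-output-channel scale factor
w_fold = w * scale[:, None]            # same reshape trick as weight_shape in the patch
b_fold = beta + gamma * (b - running_mean) * bn_istd

y_fold = relu(x @ w_fold.T + b_fold)
np.testing.assert_allclose(y_ref, y_fold, atol=1e-4)
```

The same tolerance (`atol=1e-4`) as the patch's fusion tests is used here, since the folded and unfolded paths accumulate float32 rounding differently.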