From 980ebf2c72ab7654c9311f3fcd88ae87861f3cfb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 16 May 2020 16:44:38 +0800 Subject: [PATCH] feat(mge/module): add fused conv_bn qat approximate version GitOrigin-RevId: 1b7284a5951229c8924cb880de41ebf58db19fea --- .../megengine/module/conv_bn_relu.py | 149 ++++++++++++------ .../test/unit/module/test_conv_bn_relu.py | 39 +++++ 2 files changed, 138 insertions(+), 50 deletions(-) create mode 100644 python_module/test/unit/module/test_conv_bn_relu.py diff --git a/python_module/megengine/module/conv_bn_relu.py b/python_module/megengine/module/conv_bn_relu.py index 15964fcd..af088fea 100644 --- a/python_module/megengine/module/conv_bn_relu.py +++ b/python_module/megengine/module/conv_bn_relu.py @@ -8,7 +8,7 @@ from typing import Tuple, Union from ..core import ones, zeros -from ..functional import flatten, relu, sqrt, sum +from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad from .batchnorm import BatchNorm2d from .conv import Conv2d from .module import QATModule @@ -31,7 +31,6 @@ class _ConvBn2d(QATModule): momentum=0.9, affine=True, track_running_stats=True, - freeze_bn=False, ): super().__init__() self.conv = Conv2d( @@ -47,28 +46,6 @@ class _ConvBn2d(QATModule): compute_mode, ) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) - self.freeze_bn = freeze_bn - - def update_bn_stats(self): - self.freeze_bn = False - return self - - def freeze_bn_stats(self): - self.freeze_bn = True - return self - - def get_bn_gamma_beta(self): - if self.bn.weight is None: - gamma = ones((self.bn.num_features), dtype="float32") - else: - gamma = self.bn.weight - - if self.bn.bias is None: - beta = zeros((self.bn.num_features), dtype="float32") - else: - beta = self.bn.bias - - return gamma, beta def get_batch_mean_var(self, inp): def _sum_channel(inp, axis=0, keepdims=True): @@ -83,8 +60,7 @@ class _ConvBn2d(QATModule): sum2 = _sum_channel(inp ** 2, (0, 2, 3)) reduce_size = inp.shapeof().prod() / inp.shapeof(1) batch_mean = sum1 / reduce_size - batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1) - + batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size return batch_mean, batch_var def fold_weight_bias(self, bn_mean, bn_var): @@ -92,50 +68,123 @@ class _ConvBn2d(QATModule): # bn_istd = 1 / bn_std # w_fold = gamma / bn_std * W # b_fold = gamma * (b - bn_mean) / bn_std + beta - gamma, beta = self.get_bn_gamma_beta() - b = self.conv.bias - if b is None: - b = zeros(self.conv._infer_bias_shape(), dtype="float32") + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features), dtype="float32") + beta = beta.reshape(1, -1, 1, 1) + if bn_mean is None: bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") if bn_var is None: bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") + conv_bias = self.conv.bias + if conv_bias is None: + conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + # bn_istd = 1 / bn_std + # w_fold = gamma / bn_std * W + scale_factor = gamma * bn_istd if self.conv.groups == 1: - w_fold = ( - self.conv.weight - * gamma.reshape(-1, 1, 1, 1) - * bn_istd.reshape(-1, 1, 1, 1) - ) + w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) else: - w_fold = ( - self.conv.weight - * gamma.reshape(self.conv.groups, -1, 1, 1, 1) - * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1) + w_fold = self.conv.weight * scale_factor.reshape( + self.conv.groups, -1, 1, 1, 1 ) - b_fold = flatten(beta) + ( - flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd) - ) - b_fold = b_fold.reshape(self.conv._infer_bias_shape()) + # b_fold = gamma * (b - bn_mean) / bn_std + beta + b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd return w_fold, b_fold - def calc_conv_bn_qat(self, inp): - # TODO: use pytorch method as - conv = self.conv(inp) - self.bn(conv) + def update_running_mean_and_running_var( + self, bn_mean, bn_var, num_elements_per_channel + ): + # update running mean and running var. no grad, use unbiased bn var + bn_mean = zero_grad(bn_mean) + bn_var = ( + zero_grad(bn_var) + * num_elements_per_channel + / (num_elements_per_channel - 1) + ) + exponential_average_factor = 1 - self.bn.momentum + add_update( + self.bn.running_mean, + delta=bn_mean, + alpha=1 - exponential_average_factor, + beta=exponential_average_factor, + ) + add_update( + self.bn.running_var, + delta=bn_var, + alpha=1 - exponential_average_factor, + beta=exponential_average_factor, + ) - if self.training: + def calc_conv_bn_qat(self, inp, approx=True): + if self.training and not approx: + conv = self.conv(inp) bn_mean, bn_var = self.get_batch_mean_var(conv) + num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) else: bn_mean, bn_var = self.bn.running_mean, self.bn.running_var - w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var) + # get gamma and beta in BatchNorm + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features), dtype="float32") + beta = beta.reshape(1, -1, 1, 1) + # conv_bias + conv_bias = self.conv.bias + if conv_bias is None: + conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + # bn_istd = 1 / bn_std + # w_fold = gamma / bn_std * W + scale_factor = gamma * bn_istd + if self.conv.groups == 1: + w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) + else: + w_fold = self.conv.weight * scale_factor.reshape( + self.conv.groups, -1, 1, 1, 1 + ) + b_fold = None + if not (self.training and approx): + # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta + b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd + w_qat = self.apply_fakequant_with_observer( w_fold, self.weight_fake_quant, self.weight_observer ) - return self.conv.calc_conv(inp, w_qat, b_fold) + conv = self.conv.calc_conv(inp, w_qat, b_fold) + if not (self.training and approx): + return conv + + # rescale conv to get original conv output + orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) + if self.conv.bias is not None: + orig_conv = orig_conv + self.conv.bias + # calculate batch norm + bn_mean, bn_var = self.get_batch_mean_var(orig_conv) + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + conv = gamma * bn_istd * (orig_conv - bn_mean) + beta + num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) + return conv class ConvBn2d(_ConvBn2d): diff --git a/python_module/test/unit/module/test_conv_bn_relu.py b/python_module/test/unit/module/test_conv_bn_relu.py new file mode 100644 index 00000000..c713448f --- /dev/null +++ b/python_module/test/unit/module/test_conv_bn_relu.py @@ -0,0 +1,39 @@ +import copy +from itertools import product + +import numpy as np + +from megengine import tensor +from megengine.module import ConvBn2d +from megengine.quantization import quantize_qat +from megengine.quantization.quantize import disable_fake_quant +from megengine.test import assertTensorClose + + +def test_convbn2d(): + in_channels = 32 + out_channels = 64 + kernel_size = 3 + module = ConvBn2d(in_channels, out_channels, kernel_size) + quantize_qat(module) + for groups, bias in product([1, 4], [True, False]): + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + module.train() + qat_module = copy.deepcopy(module) + disable_fake_quant(qat_module) + normal_outputs = module.forward(inputs) + qat_outputs = qat_module.forward_qat(inputs) + assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) + a = module.bn.running_mean.numpy() + b = qat_module.bn.running_mean.numpy() + assertTensorClose( + module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 + ) + assertTensorClose( + module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 + ) + module.eval() + normal_outputs = module.forward(inputs) + qat_module.eval() + qat_outputs = qat_module.forward_qat(inputs) + assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) -- GitLab