提交 980ebf2c 编写于 作者: M Megvii Engine Team

feat(mge/module): add fused conv_bn qat approximate version

GitOrigin-RevId: 1b7284a5951229c8924cb880de41ebf58db19fea
上级 6972bfde
......@@ -8,7 +8,7 @@
from typing import Tuple, Union
from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule
......@@ -31,7 +31,6 @@ class _ConvBn2d(QATModule):
momentum=0.9,
affine=True,
track_running_stats=True,
freeze_bn=False,
):
super().__init__()
self.conv = Conv2d(
......@@ -47,28 +46,6 @@ class _ConvBn2d(QATModule):
compute_mode,
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
self.freeze_bn = freeze_bn
def update_bn_stats(self):
    """Clear the ``freeze_bn`` flag on this module.

    This is the counterpart of ``freeze_bn_stats``. How the flag is consumed
    is not visible in this block.

    Returns:
        ``self``, so that calls can be chained.
    """
    self.freeze_bn = False
    return self
def freeze_bn_stats(self):
    """Set the ``freeze_bn`` flag on this module.

    This is the counterpart of ``update_bn_stats``. How the flag is consumed
    is not visible in this block.

    Returns:
        ``self``, so that calls can be chained.
    """
    self.freeze_bn = True
    return self
def get_bn_gamma_beta(self):
    """Return the batch-norm scale and shift as a ``(gamma, beta)`` pair.

    ``gamma`` is ``self.bn.weight`` and ``beta`` is ``self.bn.bias``. When
    either is ``None``, a float32 stand-in of ``self.bn.num_features``
    elements is created instead: ones for ``gamma``, zeros for ``beta``.
    """
    bn = self.bn
    # Fall back to the identity scale when the BN layer has no weight.
    gamma = (
        bn.weight
        if bn.weight is not None
        else ones((bn.num_features), dtype="float32")
    )
    # Fall back to a zero shift when the BN layer has no bias.
    beta = (
        bn.bias
        if bn.bias is not None
        else zeros((bn.num_features), dtype="float32")
    )
    return gamma, beta
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
......@@ -83,8 +60,7 @@ class _ConvBn2d(QATModule):
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
......@@ -92,50 +68,123 @@ class _ConvBn2d(QATModule):
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma, beta = self.get_bn_gamma_beta()
b = self.conv.bias
if b is None:
b = zeros(self.conv._infer_bias_shape(), dtype="float32")
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = (
self.conv.weight
* gamma.reshape(-1, 1, 1, 1)
* bn_istd.reshape(-1, 1, 1, 1)
)
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = (
self.conv.weight
* gamma.reshape(self.conv.groups, -1, 1, 1, 1)
* bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = flatten(beta) + (
flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
)
b_fold = b_fold.reshape(self.conv._infer_bias_shape())
# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold
def calc_conv_bn_qat(self, inp):
# TODO: use pytorch method as
conv = self.conv(inp)
self.bn(conv)
def update_running_mean_and_running_var(
    self, bn_mean, bn_var, num_elements_per_channel
):
    """Fold the current batch statistics into the BN running statistics.

    Args:
        bn_mean: batch mean; passed through ``zero_grad`` before use.
        bn_var: batch variance; passed through ``zero_grad`` and rescaled by
            ``n / (n - 1)`` (the unbiased estimate) before use.
        num_elements_per_channel: ``n`` in the rescaling above.

    Both ``self.bn.running_mean`` and ``self.bn.running_var`` are updated via
    ``add_update`` with ``alpha = 1 - (1 - momentum)`` and
    ``beta = 1 - momentum``.
    """
    # No gradient should flow through the running-statistics update.
    detached_mean = zero_grad(bn_mean)
    # Running variance tracks the unbiased estimate: var * n / (n - 1).
    unbiased_var = (
        zero_grad(bn_var)
        * num_elements_per_channel
        / (num_elements_per_channel - 1)
    )

    exponential_average_factor = 1 - self.bn.momentum
    # NOTE(review): presumably add_update computes
    # dest = alpha * dest + beta * delta -- confirm against its documentation.
    blend = dict(
        alpha=1 - exponential_average_factor,
        beta=exponential_average_factor,
    )
    add_update(self.bn.running_mean, delta=detached_mean, **blend)
    add_update(self.bn.running_var, delta=unbiased_var, **blend)
if self.training:
def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
return self.conv.calc_conv(inp, w_qat, b_fold)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv
# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv
class ConvBn2d(_ConvBn2d):
......
import copy
from itertools import product
import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.test import assertTensorClose
def test_convbn2d():
    """Check that the QAT forward of ``ConvBn2d`` matches its float forward.

    For every ``(groups, bias)`` combination a fresh module is built, prepared
    with ``quantize_qat`` and deep-copied; fake quantization is disabled on
    the copy. In training mode the outputs and the BN running statistics of
    both modules must agree; in eval mode the outputs must agree.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3
    for groups, bias in product([1, 4], [True, False]):
        # Build the module inside the loop so each (groups, bias) pair is
        # really exercised; a single module built outside the loop would
        # test the default configuration four times.
        module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        quantize_qat(module)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        # Training mode: outputs and running statistics must match.
        module.train()
        qat_module = copy.deepcopy(module)
        disable_fake_quant(qat_module)
        normal_outputs = module.forward(inputs)
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
        assertTensorClose(
            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
        assertTensorClose(
            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )

        # Eval mode: both paths use the running statistics.
        module.eval()
        normal_outputs = module.forward(inputs)
        qat_module.eval()
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册