Commit 980ebf2c authored by Megvii Engine Team

feat(mge/module): add fused conv_bn qat approximate version

GitOrigin-RevId: 1b7284a5951229c8924cb880de41ebf58db19fea
Parent 6972bfde
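For reference, the folding performed by fold_weight_bias in the diff below follows the standard Conv-BN fusion identity, where gamma and beta are the BatchNorm affine parameters, mu and sigma^2 the (batch or running) mean and variance, and b the convolution bias; a sketch in LaTeX:

% Conv-BN folding identity used by fold_weight_bias
\[
W_{\mathrm{fold}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\,W,
\qquad
b_{\mathrm{fold}} = \beta + \gamma\,\frac{b - \mu}{\sqrt{\sigma^2 + \epsilon}}
\]

In the approximate QAT path added by this commit (approx=True while training), the convolution runs with the folded weight only; its output is divided by the per-channel scale factor gamma / sqrt(sigma^2 + eps) to recover the un-fused convolution result, and batch normalization is then applied with the statistics of the current batch.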
@@ -8,7 +8,7 @@
 from typing import Tuple, Union
 from ..core import ones, zeros
-from ..functional import flatten, relu, sqrt, sum
+from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
 from .module import QATModule
@@ -31,7 +31,6 @@ class _ConvBn2d(QATModule):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
-        freeze_bn=False,
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -47,28 +46,6 @@ class _ConvBn2d(QATModule):
             compute_mode,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
-        self.freeze_bn = freeze_bn
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        return self
-
-    def get_bn_gamma_beta(self):
-        if self.bn.weight is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        else:
-            gamma = self.bn.weight
-        if self.bn.bias is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        else:
-            beta = self.bn.bias
-        return gamma, beta
-
     def get_batch_mean_var(self, inp):
         def _sum_channel(inp, axis=0, keepdims=True):
@@ -83,8 +60,7 @@ class _ConvBn2d(QATModule):
         sum2 = _sum_channel(inp ** 2, (0, 2, 3))
         reduce_size = inp.shapeof().prod() / inp.shapeof(1)
         batch_mean = sum1 / reduce_size
-        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
         return batch_mean, batch_var

     def fold_weight_bias(self, bn_mean, bn_var):
@@ -92,50 +68,123 @@ class _ConvBn2d(QATModule):
         # bn_istd = 1 / bn_std
         # w_fold = gamma / bn_std * W
         # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        gamma, beta = self.get_bn_gamma_beta()
-        b = self.conv.bias
-        if b is None:
-            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
         if bn_mean is None:
             bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
         if bn_var is None:
             bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
         if self.conv.groups == 1:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(-1, 1, 1, 1)
-                * bn_istd.reshape(-1, 1, 1, 1)
-            )
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
         else:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
-                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
-            )
-        b_fold = flatten(beta) + (
-            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
-        )
-        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         return w_fold, b_fold

-    def calc_conv_bn_qat(self, inp):
-        # TODO: use pytorch method as
-        conv = self.conv(inp)
-        self.bn(conv)
-
-        if self.training:
-            bn_mean, bn_var = self.get_batch_mean_var(conv)
-        else:
-            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
-        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
-        w_qat = self.apply_fakequant_with_observer(
-            w_fold, self.weight_fake_quant, self.weight_observer
-        )
-        return self.conv.calc_conv(inp, w_qat, b_fold)
+    def update_running_mean_and_running_var(
+        self, bn_mean, bn_var, num_elements_per_channel
+    ):
+        # update running mean and running var. no grad, use unbiased bn var
+        bn_mean = zero_grad(bn_mean)
+        bn_var = (
+            zero_grad(bn_var)
+            * num_elements_per_channel
+            / (num_elements_per_channel - 1)
+        )
+        exponential_average_factor = 1 - self.bn.momentum
+        add_update(
+            self.bn.running_mean,
+            delta=bn_mean,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+        add_update(
+            self.bn.running_var,
+            delta=bn_var,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+
+    def calc_conv_bn_qat(self, inp, approx=True):
+        if self.training and not approx:
+            conv = self.conv(inp)
+            bn_mean, bn_var = self.get_batch_mean_var(conv)
+            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+            self.update_running_mean_and_running_var(
+                bn_mean, bn_var, num_elements_per_channel
+            )
+        else:
+            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
+
+        # get gamma and beta in BatchNorm
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+        # conv_bias
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        b_fold = None
+        if not (self.training and approx):
+            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
+            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
+
+        w_qat = self.apply_fakequant_with_observer(
+            w_fold, self.weight_fake_quant, self.weight_observer
+        )
+        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        if not (self.training and approx):
+            return conv
+
+        # rescale conv to get original conv output
+        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
+        if self.conv.bias is not None:
+            orig_conv = orig_conv + self.conv.bias
+        # calculate batch norm
+        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
+        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+        self.update_running_mean_and_running_var(
+            bn_mean, bn_var, num_elements_per_channel
+        )
+        return conv


 class ConvBn2d(_ConvBn2d):
......
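The rescaling step above ("rescale conv to get original conv output") relies on the fact that scaling each output channel's filter by s_c = gamma_c / sqrt(sigma_c^2 + eps) scales that channel's convolution output by the same factor, so dividing by s_c recovers the un-fused convolution. A minimal numpy check of this identity (illustration only, not part of the commit), writing a 1x1 convolution as a matrix product:

import numpy as np

np.random.seed(0)
x = np.random.randn(8, 3)        # 8 spatial positions, 3 input channels (1x1 conv)
w = np.random.randn(3, 4)        # filters for 4 output channels
scale = np.random.rand(4) + 0.5  # per-output-channel factor gamma * bn_istd

conv = x @ w                      # un-fused convolution output
conv_folded = x @ (w * scale)     # convolution with folded (scaled) weights
assert np.allclose(conv_folded / scale, conv)

The new test below exercises both paths: it compares forward_qat against the float ConvBn2d forward in training and evaluation mode, and checks that both modules end up with matching running statistics.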
import copy
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.test import assertTensorClose


def test_convbn2d():
    in_channels = 32
    out_channels = 64
    kernel_size = 3
    module = ConvBn2d(in_channels, out_channels, kernel_size)
    quantize_qat(module)

    for groups, bias in product([1, 4], [True, False]):
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        module.train()
        qat_module = copy.deepcopy(module)
        disable_fake_quant(qat_module)
        normal_outputs = module.forward(inputs)
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
        a = module.bn.running_mean.numpy()
        b = qat_module.bn.running_mean.numpy()
        assertTensorClose(
            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
        assertTensorClose(
            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )
        module.eval()
        normal_outputs = module.forward(inputs)
        qat_module.eval()
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
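The running-statistics assertions in the test depend on update_running_mean_and_running_var, which converts the biased batch variance (divided by N) into the unbiased estimate (divided by N - 1) and blends it into the running buffers via add_update, i.e. dest <- alpha * dest + beta * delta. A small numpy sketch of the equivalent update (illustrative names, not MegEngine API), assuming the default momentum=0.9:

import numpy as np

def update_running_stats(running_mean, running_var, batch_mean, batch_var,
                         num_elements_per_channel, momentum=0.9):
    # Unbiased variance: rescale the batch variance by N / (N - 1).
    unbiased_var = batch_var * num_elements_per_channel / (num_elements_per_channel - 1)
    # Exponential moving average, matching alpha = momentum, beta = 1 - momentum above.
    new_mean = momentum * running_mean + (1 - momentum) * batch_mean
    new_var = momentum * running_var + (1 - momentum) * unbiased_var
    return new_mean, new_var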