Commit 2694ff81 authored by Megvii Engine Team

fix(mge/module): fix some deconv fuse bn problem

GitOrigin-RevId: e88a63328065234c8060317d7c7e558c95bfbd4b
Parent 1886ebc3
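
Background for the fix (editor's sketch, not part of the commit): folding BatchNorm into a convolution rescales each output channel of the weight by gamma / sqrt(var + eps). A Conv2d weight is laid out as (out, in, h, w), so the scale broadcasts on axis 0; a ConvTranspose2d weight is laid out as (in, out, h, w), so the scale must broadcast on axis 1 instead. Using the Conv2d reshape for a deconv scales the wrong axis, which is the bug the hunks below fix. A minimal numpy sketch of the folding arithmetic, with all shapes and values hypothetical:

import numpy as np

ic, oc, kh, kw = 4, 8, 3, 3
w = np.random.randn(ic, oc, kh, kw).astype("float32")  # ConvTranspose2d weight: (in, out, h, w)
gamma = np.random.randn(oc).astype("float32")          # bn.weight
beta = np.random.randn(oc).astype("float32")           # bn.bias
mean = np.random.randn(oc).astype("float32")           # bn.running_mean
var = np.abs(np.random.randn(oc)).astype("float32")    # bn.running_var
eps = 1e-5

istd = 1.0 / np.sqrt(var + eps)
scale = gamma * istd                                   # one scale per output channel
# Conv2d weight (out, in, h, w):        w * scale.reshape(-1, 1, 1, 1)
# ConvTranspose2d weight (in, out, h, w): output axis is 1, hence this commit's reshape
w_fold = w * scale.reshape(1, -1, 1, 1)
b_fold = beta - mean * scale                           # folded bias (conv bias assumed None)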
@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         # get fold bn conv_transpose2d param
         gamma = self.bn.weight
         if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        gamma = gamma.reshape(1, -1, 1, 1)
+            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
         beta = self.bn.bias
         if beta is None:
             beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         scale_factor = gamma * bn_istd
         if self.conv_transpose2d.groups == 1:
-            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
+            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
         else:
             w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
-                self.conv_transpose2d.groups, -1, 1, 1, 1
+                self.conv_transpose2d.groups, 1, -1, 1, 1
             )
         w_fold = self.apply_quant_weight(w_fold)
......
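
The grouped (5-D weight) branch follows the same rule. Judging from the reshape in the hunk above, a grouped ConvTranspose2d weight is laid out as (groups, in_per_group, out_per_group, h, w), so the per-output-channel scale sits on axis 2, not axis 1 as the old code assumed. A hypothetical numpy check of the broadcast:

import numpy as np

groups, ipg, opg, kh, kw = 4, 2, 3, 3, 3                  # hypothetical sizes
w = np.random.randn(groups, ipg, opg, kh, kw).astype("float32")
scale = np.random.randn(groups * opg).astype("float32")   # out_channels = groups * opg
w_fold = w * scale.reshape(groups, 1, -1, 1, 1)           # new reshape: scales axis 2
assert w_fold.shape == w.shape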
@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
 def fold_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
-    shape = (1, -1, 1, 1)
+    shape = (-1, 1, 1, 1)
     if transpose:
-        shape = (-1, 1, 1, 1)
+        shape = (1, -1, 1, 1)
     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
-        groups, num_features = kernel_shape[0], kernel_shape[1]
+        if transpose:
+            groups, num_features = kernel_shape[0], kernel_shape[2]
+        else:
+            groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
-        groups, num_features = 1, kernel_shape[0]
+        if transpose:
+            groups, num_features = 1, kernel_shape[1]
+        else:
+            groups, num_features = 1, kernel_shape[0]
     out_channels = groups * num_features
     if gamma is None:
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         compute_mode=conv.compute_mode,
         name=conv.name,
     )
-    new_conv = module if bn is None or not conv.training else module.conv
+    if isinstance(conv, ConvTranspose2d):
+        module.output_padding = conv.output_padding
+        new_conv = (
+            module if bn is None or not conv.training else module.conv_transpose2d
+        )
+    else:
+        new_conv = module if bn is None or not conv.training else module.conv
     weight, bias = conv.weight, conv.bias
     if not conv.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
-        )
+        if isinstance(conv, ConvTranspose2d):
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+                transpose=True,
+            )
+        else:
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+            )
     new_conv.weight = Parameter(weight)
     if bias is not None:
         new_conv.bias = Parameter(bias)
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
-
-
-def fuse_conv_transpose2d_bn_relu_module(
-    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
-):
-    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
-    if bn:
-        assert (
-            conv_transpose2d.training == bn.training
-        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
-        assert (
-            bn.num_features == conv_transpose2d.out_channels
-        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
-    module_key = module_key + (conv_transpose2d.training,)
-    module = _MAP_TO_FUSED_MODULE[module_key](
-        in_channels=conv_transpose2d.in_channels,
-        out_channels=conv_transpose2d.out_channels,
-        kernel_size=conv_transpose2d.kernel_size,
-        stride=conv_transpose2d.stride,
-        padding=conv_transpose2d.padding,
-        output_padding=conv_transpose2d.output_padding,
-        dilation=conv_transpose2d.dilation,
-        groups=conv_transpose2d.groups,
-        bias=conv_transpose2d.bias is not None,
-        conv_mode=conv_transpose2d.conv_mode,
-        compute_mode=conv_transpose2d.compute_mode,
-        name=conv_transpose2d.name,
-    )
-    new_conv_transpose2d = (
-        module
-        if bn is None or not conv_transpose2d.training
-        else module.conv_transpose2d
-    )
-    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
-    if not conv_transpose2d.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight,
-            bias,
-            bn.weight,
-            bn.bias,
-            bn.running_mean,
-            bn.running_var,
-            bn.eps,
-            transpose=False,
-        )
-    new_conv_transpose2d.weight = Parameter(weight)
-    if bias is not None:
-        new_conv_transpose2d.bias = Parameter(bias)
-    if bn is not None and conv_transpose2d.training:
-        module.bn = deepcopy(bn)
-    new_conv_transpose2d.training = conv_transpose2d.training
-    return module
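
With the separate fuse_conv_transpose2d_bn_relu_module removed, fuse_conv_bn_relu_module is the single entry point and selects the transpose branch via isinstance. A usage sketch consistent with the new tests further down (module sizes are arbitrary):

from megengine.module import BatchNorm2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

deconv = ConvTranspose2d(16, 32, 3)
bn = BatchNorm2d(32)
relu = ReLU()
deconv.eval()
bn.eval()
relu.eval()
# In eval mode the BN statistics are folded into the deconv weights (transpose=True path).
fused = fuse_conv_bn_relu_module(deconv, bn, relu)
fused.eval()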
@@ -34,35 +34,49 @@ def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_bn = ConvBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
+            net.conv_bn.bn.running_mean.numpy(),
+            qat_net.conv_bn.bn.running_mean.numpy(),
             atol=5e-8,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_bn.bn.running_var.numpy(),
+            qat_net.conv_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_transpose_bn = ConvTransposeBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_transpose_bn(out)
+            out = self.dequant(out)
+            return out
+
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvTransposeBn2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            output_padding=0,
-            groups=groups,
-            bias=bias,
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
         inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
-        )
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
-            atol=5e-8,
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
        )
        np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_transpose_bn.bn.running_var.numpy(),
+            qat_net.conv_transpose_bn.bn.running_var.numpy(),
+            atol=5e-7,
        )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
......
@@ -3,6 +3,15 @@ import pytest
 from megengine import Parameter, Tensor
 from megengine import module as Float
 from megengine.functional import ones, zeros
+from megengine.module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ReLU,
+)
 from megengine.module import qat as QAT
 from megengine.module import quantized as Q
 from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
+from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
 
 
 class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_ConvBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = Conv2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+
+def test_ConvTransposeBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )