Commit e6820b91, authored by Megvii Engine Team, committed by Xu Xinran

feat(mge/module): add conv and conv_relu quantization module

GitOrigin-RevId: 9cd668d97b4ccae8adfa801fd43856cd3fdca813
Parent: a1f8ecc7
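In rough terms, this commit threads a new fused Conv2d/ConvRelu2d through all three module flavors (float, QAT, quantized) and moves the ConvBn modules from `conv_bn_relu.py` to `conv_bn.py`. A minimal sketch of the intended workflow, modeled on the test added at the bottom of this diff (the `Net` structure and channel sizes here are illustrative, not part of the commit):

```python
import numpy as np

from megengine import tensor
from megengine.module import ConvRelu2d, DequantStub, Module, QuantStub
from megengine.quantization.quantize import disable_fake_quant, quantize_qat


class Net(Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DequantStub()
        self.conv_relu = ConvRelu2d(3, 8, 3, bias=True)

    def forward(self, inp):
        out = self.quant(inp)
        out = self.conv_relu(out)
        return self.dequant(out)


net = Net()
qat_net = quantize_qat(net, inplace=False)  # Float ConvRelu2d -> qat.ConvRelu2d
disable_fake_quant(qat_net)  # no calibration here, so run in pure float
inputs = tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
assert np.allclose(net(inputs).numpy(), qat_net(inputs).numpy(), atol=1e-5)
```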
@@ -9,8 +9,8 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
-from .conv import Conv2d, ConvTranspose2d, LocalConv2d
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
......
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,8 +12,8 @@ import numpy as np
 import megengine._internal as mgb

+from .. import functional as F
 from ..core import Parameter
-from ..functional import conv2d, conv_transpose2d, local_conv2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -183,7 +182,7 @@ class Conv2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def calc_conv(self, inp, weight, bias):
-        return conv2d(
+        return F.conv2d(
             inp,
             weight,
             bias,
@@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def forward(self, inp):
-        return conv_transpose2d(
+        return F.conv_transpose2d(
             inp,
             self.weight,
             self.bias,
@@ -377,6 +376,17 @@ class LocalConv2d(Conv2d):
         )

     def forward(self, inp):
-        return local_conv2d(
+        return F.local_conv2d(
             inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
         )
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
+    :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return F.relu(self.calc_conv(inp, self.weight, self.bias))
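Functionally, the float ConvRelu2d is just Conv2d with a ReLU applied to the output; the fusion only pays off later, when the quantized version runs both as a single `conv_bias_activation` kernel. A quick sanity sketch (channel sizes illustrative; parameters shared by hand so the two paths compute the same thing):

```python
import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.module import Conv2d, ConvRelu2d

x = tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))

fused = ConvRelu2d(4, 6, 3)
plain = Conv2d(4, 6, 3)
plain.weight = fused.weight  # share parameters so outputs match
plain.bias = fused.bias

# The fused forward is exactly conv followed by relu.
assert np.allclose(fused(x).numpy(), F.relu(plain(x)).numpy())
```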
@@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module):
 class ConvBn2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d):
 class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
     :func:`~.quantize.quantize_qat`.
     """
......
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QATModule
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from .. import conv as Float
from .module import QATModule


class Conv2d(Float.Conv2d, QATModule):
    r"""
    A :class:`~.QATModule` Conv2d with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def calc_conv_qat(self, inp):
        w_qat = self.apply_quant_weight(self.weight)
        conv = self.calc_conv(inp, w_qat, self.bias)
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float.Conv2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.in_channels,
            float_module.out_channels,
            float_module.kernel_size,
            float_module.stride,
            float_module.padding,
            float_module.dilation,
            float_module.groups,
            float_module.bias is not None,
            float_module.conv_mode.name,
            float_module.compute_mode.name,
        )
        qat_module.weight = float_module.weight
        qat_module.bias = float_module.bias
        return qat_module

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_qat(inp))


class ConvRelu2d(Conv2d):
    r"""
    A :class:`~.QATModule` including Conv2d and ReLU with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp)))
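QAT modules are built from their float counterparts through `from_float_module`, which `quantize_qat` drives via a type lookup (`_float2qat_dict`). A hand-driven sketch of that conversion, assuming `megengine.module.qat` re-exports `ConvRelu2d` as the `__init__.py` diff above shows (the aliases are mine):

```python
from megengine.module import ConvRelu2d as FloatConvRelu2d
from megengine.module.qat import ConvRelu2d as QATConvRelu2d

float_mod = FloatConvRelu2d(8, 16, 3, bias=True)
qat_mod = QATConvRelu2d.from_float_module(float_mod)

# Parameters are shared with the float module, not copied:
assert qat_mod.weight is float_mod.weight
assert qat_mod.bias is float_mod.bias
```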
@@ -7,7 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
-from .. import conv_bn_relu as Float
+from .. import conv_bn as Float
 from .module import QATModule
@@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.padding,
             float_module.conv.dilation,
             float_module.conv.groups,
-            bool(float_module.conv.bias),
+            float_module.conv.bias is not None,
             float_module.conv.conv_mode.name,
             float_module.conv.compute_mode.name,
         )
......
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QuantizedModule
......
@@ -7,16 +7,19 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

+import numpy as np
+
 import megengine._internal as mgb

 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ..qat import conv_bn_relu as QAT
+from ..qat import conv as QAT
 from .module import QuantizedModule


-class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
+class Conv2d(Float.Conv2d, QuantizedModule):
+    r"""quantized version of :class:`~.qat.conv.Conv2d`."""
+
     r"""Applies a 2D convolution over a quantized input tensor, inference only.

     The parameters are the same as :class:`~.Conv2d`.
     """
@@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
         )

     @classmethod
-    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+    def from_qat_module(cls, qat_module: QAT.Conv2d):
         r"""
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
         qconv = cls(
-            qat_module.conv.in_channels,
-            qat_module.conv.out_channels,
-            qat_module.conv.kernel_size,
-            qat_module.conv.stride,
-            qat_module.conv.padding,
-            qat_module.conv.dilation,
-            qat_module.conv.groups,
+            qat_module.in_channels,
+            qat_module.out_channels,
+            qat_module.kernel_size,
+            qat_module.stride,
+            qat_module.padding,
+            qat_module.dilation,
+            qat_module.groups,
             dtype=output_dtype,
         )
-        w_fold, b_fold = qat_module.fold_weight_bias(
-            qat_module.bn.running_mean, qat_module.bn.running_var
-        )
-        weight = w_fold.astype(qat_module.get_weight_dtype())
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qconv.weight = Parameter(weight.numpy())
-        qconv.bias = Parameter(b_fold.numpy())
+        if qat_module.bias is not None:
+            qconv.bias = Parameter(qat_module.bias.numpy())
+        else:
+            qconv.bias = Parameter(
+                np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
+            )
         return qconv

-
-class ConvBn2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
-
     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


-class ConvBnRelu2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+class ConvRelu2d(Conv2d):
+    r"""quantized version of :class:`~.qat.conv.ConvRelu2d`."""

     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
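Note the asymmetry with the QAT module: the quantized module always carries a bias, because `conv_bias_activation` expects one, so when the QAT source has `bias=None`, `from_qat_module` fills in zeros of the inferred bias shape. A small numpy illustration of that shape, grounded in the `(1, self.out_channels, 1, 1)` returned by the float module's `_infer_bias_shape` earlier in this diff:

```python
import numpy as np

# Shape produced by Conv2d._infer_bias_shape(): one bias value per output
# channel, broadcastable over an (N, C, H, W) activation tensor.
out_channels = 16
bias = np.zeros((1, out_channels, 1, 1), dtype=np.float32)

activation = np.ones((4, out_channels, 8, 8), dtype=np.float32)
assert (activation + bias).shape == activation.shape  # broadcasts cleanly
```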
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import Parameter
from ..qat import conv_bn as QAT
from .conv import Conv2d


class _ConvBnActivation2d(Conv2d):
    r"""Applies a 2D convolution over a quantized input tensor, inference only.

    The parameters are the same as :class:`~.Conv2d`.
    """

    @classmethod
    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv = cls(
            qat_module.conv.in_channels,
            qat_module.conv.out_channels,
            qat_module.conv.kernel_size,
            qat_module.conv.stride,
            qat_module.conv.padding,
            qat_module.conv.dilation,
            qat_module.conv.groups,
            dtype=output_dtype,
        )
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qconv.weight = Parameter(weight.numpy())
        qconv.bias = Parameter(b_fold.numpy())
        return qconv


class ConvBn2d(_ConvBnActivation2d):
    r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
@@ -104,6 +104,10 @@ def quantize_qat(
     for key, submodule, parent in module._flatten(
         with_key=True, with_parent=True, predicate=is_quantable
     ):
+        # only convert the topmost quantable module; quantable children
+        # (e.g. the Conv2d inside a fused ConvBn2d) are handled by their parent.
+        if is_quantable(parent):
+            continue
         new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
         if isinstance(parent, Float.Sequential):
             # cannot use setattr to be compatible with Sequential's ``__setitem__``
......
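The new guard matters precisely for the fused modules this commit adds: `_flatten` with `predicate=is_quantable` yields both a quantable parent (say ConvBn2d) and its quantable children (the inner Conv2d), and converting both would apply quantization twice. A toy, MegEngine-free sketch of the skip rule:

```python
# Toy illustration of the guard above (not MegEngine code): given a flattened
# stream of (submodule, parent) pairs, only submodules whose parent is NOT
# itself quantable are converted; children ride along with the fused parent.
quantable = {"convbn", "convbn.conv_inner"}  # fused parent and its inner conv


def topmost_quantable(pairs):
    for name, parent in pairs:
        if parent in quantable:  # parent already quantable -> skip the child
            continue
        yield name


pairs = [("convbn", "root"), ("convbn.conv_inner", "convbn")]
assert list(topmost_quantable(pairs)) == ["convbn"]
```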
+import copy
 from itertools import product

 import numpy as np

 from megengine import tensor
-from megengine.module import ConvBn2d
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
 from megengine.quantization.quantize import disable_fake_quant, quantize_qat
 from megengine.test import assertTensorClose


-def test_convbn2d():
+def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
@@ -35,3 +41,45 @@ def test_qat_convbn2d():
         qat_module.eval()
         qat_outputs = qat_module(inputs)
         assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)