diff --git a/python_module/megengine/module/__init__.py b/python_module/megengine/module/__init__.py index 7fe65951cc63cb30427c58b5b347e2f9a7ab9590..c2b3db8ad4fdbe0ca018c42b3d0a3eddf8d50b7b 100644 --- a/python_module/megengine/module/__init__.py +++ b/python_module/megengine/module/__init__.py @@ -9,8 +9,8 @@ from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm from .concat import Concat -from .conv import Conv2d, ConvTranspose2d, LocalConv2d -from .conv_bn_relu import ConvBn2d, ConvBnRelu2d +from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d +from .conv_bn import ConvBn2d, ConvBnRelu2d from .dropout import Dropout from .elemwise import Elemwise from .embedding import Embedding diff --git a/python_module/megengine/module/conv.py b/python_module/megengine/module/conv.py index 96748f768b6fd52216df61d9688d480fb656db17..02165b891e615184410d682e50735c653082501b 100644 --- a/python_module/megengine/module/conv.py +++ b/python_module/megengine/module/conv.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -13,8 +12,8 @@ import numpy as np import megengine._internal as mgb +from .. import functional as F from ..core import Parameter -from ..functional import conv2d, conv_transpose2d, local_conv2d from ..utils.types import _pair, _pair_nonzero from . import init from .module import Module @@ -183,7 +182,7 @@ class Conv2d(_ConvNd): return (1, self.out_channels, 1, 1) def calc_conv(self, inp, weight, bias): - return conv2d( + return F.conv2d( inp, weight, bias, @@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd): return (1, self.out_channels, 1, 1) def forward(self, inp): - return conv_transpose2d( + return F.conv_transpose2d( inp, self.weight, self.bias, @@ -324,7 +323,7 @@ class LocalConv2d(Conv2d): spatial dimensions. Only zero-padding is supported. Default: 0 :param groups: number of groups to divide input and output channels into, so as to perform a "grouped convolution". When ``groups`` is not 1, - ``in_channels`` and ``out_channels`` must be divisible by ``groups``. + ``in_channels`` and ``out_channels`` must be divisible by ``groups``. The shape of weight is ``(groups, output_height, output_width, in_channels // groups, *kernel_size, out_channels // groups)``. """ @@ -377,6 +376,17 @@ class LocalConv2d(Conv2d): ) def forward(self, inp): - return local_conv2d( + return F.local_conv2d( inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode ) + + +class ConvRelu2d(Conv2d): + r""" + A fused :class:`~.Module` including Conv2d and relu. Could be replaced + with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using + :func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return F.relu(self.calc_conv(inp, self.weight, self.bias)) diff --git a/python_module/megengine/module/conv_bn_relu.py b/python_module/megengine/module/conv_bn.py similarity index 92% rename from python_module/megengine/module/conv_bn_relu.py rename to python_module/megengine/module/conv_bn.py index bb3a857736bad07f398f5a03162796615e8c22d6..76713b0f81e502900de5ce34b2faa96ddda595a2 100644 --- a/python_module/megengine/module/conv_bn_relu.py +++ b/python_module/megengine/module/conv_bn.py @@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module): class ConvBn2d(_ConvBnActivation2d): r""" A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced - with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using + with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using :func:`~.quantize.quantize_qat`. """ @@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d): class ConvBnRelu2d(_ConvBnActivation2d): r""" A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced - with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using + with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using :func:`~.quantize.quantize_qat`. """ diff --git a/python_module/megengine/module/qat/__init__.py b/python_module/megengine/module/qat/__init__.py index 26c4b6eeb858f1755fd74914673a7e8370a5d442..b6adab4dc687a322fba6dd5652bdf8975933ad3a 100644 --- a/python_module/megengine/module/qat/__init__.py +++ b/python_module/megengine/module/qat/__init__.py @@ -6,7 +6,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from .concat import Concat -from .conv_bn_relu import ConvBn2d, ConvBnRelu2d +from .conv import Conv2d, ConvRelu2d +from .conv_bn import ConvBn2d, ConvBnRelu2d from .elemwise import Elemwise from .linear import Linear from .module import QATModule diff --git a/python_module/megengine/module/qat/conv.py b/python_module/megengine/module/qat/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..489f94cb4cabcb5a4e86db198fe8086ba43659d6 --- /dev/null +++ b/python_module/megengine/module/qat/conv.py @@ -0,0 +1,57 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ... import functional as F +from .. import conv as Float +from .module import QATModule + + +class Conv2d(Float.Conv2d, QATModule): + r""" + A :class:`~.QATModule` Conv2d with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + """ + + def calc_conv_qat(self, inp): + w_qat = self.apply_quant_weight(self.weight) + conv = self.calc_conv(inp, w_qat, self.bias) + return conv + + @classmethod + def from_float_module(cls, float_module: Float.Conv2d): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + qat_module = cls( + float_module.in_channels, + float_module.out_channels, + float_module.kernel_size, + float_module.stride, + float_module.padding, + float_module.dilation, + float_module.groups, + float_module.bias is not None, + float_module.conv_mode.name, + float_module.compute_mode.name, + ) + qat_module.weight = float_module.weight + qat_module.bias = float_module.bias + return qat_module + + def forward(self, inp): + return self.apply_quant_activation(self.calc_conv_qat(inp)) + + +class ConvRelu2d(Conv2d): + r""" + A :class:`~.QATModule` include Conv2d and Relu with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp))) diff --git a/python_module/megengine/module/qat/conv_bn_relu.py b/python_module/megengine/module/qat/conv_bn.py similarity index 98% rename from python_module/megengine/module/qat/conv_bn_relu.py rename to python_module/megengine/module/qat/conv_bn.py index 4a8be9e17dca2474fd69f03585977c2dd7c0533f..b62270b61d76670adfc234355634cb4fdb2f2ec7 100644 --- a/python_module/megengine/module/qat/conv_bn_relu.py +++ b/python_module/megengine/module/qat/conv_bn.py @@ -7,7 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ...core import ones, zeros from ...functional import add_update, relu, sqrt, sum, zero_grad -from .. import conv_bn_relu as Float +from .. import conv_bn as Float from .module import QATModule @@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): float_module.conv.padding, float_module.conv.dilation, float_module.conv.groups, - bool(float_module.conv.bias), + float_module.conv.bias is not None, float_module.conv.conv_mode.name, float_module.conv.compute_mode.name, ) diff --git a/python_module/megengine/module/quantized/__init__.py b/python_module/megengine/module/quantized/__init__.py index b79977214d279ec7c9ddfa402ba42ef0d098e980..e641476d6a363a609660fb2495bf946e91b7b6c8 100644 --- a/python_module/megengine/module/quantized/__init__.py +++ b/python_module/megengine/module/quantized/__init__.py @@ -6,7 +6,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from .concat import Concat -from .conv_bn_relu import ConvBn2d, ConvBnRelu2d +from .conv import Conv2d, ConvRelu2d +from .conv_bn import ConvBn2d, ConvBnRelu2d from .elemwise import Elemwise from .linear import Linear from .module import QuantizedModule diff --git a/python_module/megengine/module/quantized/conv_bn_relu.py b/python_module/megengine/module/quantized/conv.py similarity index 74% rename from python_module/megengine/module/quantized/conv_bn_relu.py rename to python_module/megengine/module/quantized/conv.py index 6b72921e0e8b45a08ebb949b0720d58b5eb0d806..3118451de6ed5b145a3ae473ea63cbed407ab254 100644 --- a/python_module/megengine/module/quantized/conv_bn_relu.py +++ b/python_module/megengine/module/quantized/conv.py @@ -7,16 +7,19 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Tuple, Union +import numpy as np + import megengine._internal as mgb from ... import module as Float from ...core import Parameter from ...functional import conv_bias_activation -from ..qat import conv_bn_relu as QAT +from ..qat import conv as QAT from .module import QuantizedModule -class _ConvBnActivation2d(Float.Conv2d, QuantizedModule): +class Conv2d(Float.Conv2d, QuantizedModule): + r"""quantized version of :class:`~.qat.conv.Conv2d`.""" r"""Applies a 2D convolution over an quantized input tensor, inference only. The parameter is same with :class: `~.Conv2d` @@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule): ) @classmethod - def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d): + def from_qat_module(cls, qat_module: QAT.Conv2d): r""" return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ output_dtype = qat_module.get_activation_dtype() qconv = cls( - qat_module.conv.in_channels, - qat_module.conv.out_channels, - qat_module.conv.kernel_size, - qat_module.conv.stride, - qat_module.conv.padding, - qat_module.conv.dilation, - qat_module.conv.groups, + qat_module.in_channels, + qat_module.out_channels, + qat_module.kernel_size, + qat_module.stride, + qat_module.padding, + qat_module.dilation, + qat_module.groups, dtype=output_dtype, ) - w_fold, b_fold = qat_module.fold_weight_bias( - qat_module.bn.running_mean, qat_module.bn.running_var - ) - weight = w_fold.astype(qat_module.get_weight_dtype()) + weight = qat_module.weight.astype(qat_module.get_weight_dtype()) qconv.weight = Parameter(weight.numpy()) - qconv.bias = Parameter(b_fold.numpy()) + if qat_module.bias is not None: + qconv.bias = Parameter(qat_module.bias.numpy()) + else: + qconv.bias = Parameter( + np.zeros(qat_module._infer_bias_shape(), dtype=np.float32) + ) return qconv - -class ConvBn2d(_ConvBnActivation2d): - r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`.""" - def forward(self, inp): return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") -class ConvBnRelu2d(_ConvBnActivation2d): - r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`.""" +class ConvRelu2d(Conv2d): + r"""quantized version of :class:`~.qat.conv.ConvRelu2d`.""" def forward(self, inp): return self.calc_conv_quantized(inp, nonlinear_mode="RELU") diff --git a/python_module/megengine/module/quantized/conv_bn.py b/python_module/megengine/module/quantized/conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb36d13f1aa06067e89aacb097344c24f6e0a0a --- /dev/null +++ b/python_module/megengine/module/quantized/conv_bn.py @@ -0,0 +1,56 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ...core import Parameter +from ..qat import conv_bn as QAT +from .conv import Conv2d + + +class _ConvBnActivation2d(Conv2d): + r"""Applies a 2D convolution over an quantized input tensor, inference only. + + The parameter is same with :class: `~.Conv2d` + """ + + @classmethod + def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + output_dtype = qat_module.get_activation_dtype() + qconv = cls( + qat_module.conv.in_channels, + qat_module.conv.out_channels, + qat_module.conv.kernel_size, + qat_module.conv.stride, + qat_module.conv.padding, + qat_module.conv.dilation, + qat_module.conv.groups, + dtype=output_dtype, + ) + w_fold, b_fold = qat_module.fold_weight_bias( + qat_module.bn.running_mean, qat_module.bn.running_var + ) + weight = w_fold.astype(qat_module.get_weight_dtype()) + qconv.weight = Parameter(weight.numpy()) + qconv.bias = Parameter(b_fold.numpy()) + return qconv + + +class ConvBn2d(_ConvBnActivation2d): + r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`.""" + + def forward(self, inp): + return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") + + +class ConvBnRelu2d(_ConvBnActivation2d): + r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`.""" + + def forward(self, inp): + return self.calc_conv_quantized(inp, nonlinear_mode="RELU") diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index c5558a2be3a4668b187f7e870fbc9aa408ad74da..2424e22ff0ad867ab0389d7270458791d4fa5852 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -104,6 +104,10 @@ def quantize_qat( for key, submodule, parent in module._flatten( with_key=True, with_parent=True, predicate=is_quantable ): + # only convert top quantable module. + if is_quantable(parent): + continue + new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) if isinstance(parent, Float.Sequential): # cannnot use setattr to be compatible with Sequential's ``__setitem__`` diff --git a/python_module/test/unit/module/test_conv_bn_relu.py b/python_module/test/unit/module/test_conv_bn_relu.py deleted file mode 100644 index 308adf4fc1e9793951b8182312a492c7b36e69a9..0000000000000000000000000000000000000000 --- a/python_module/test/unit/module/test_conv_bn_relu.py +++ /dev/null @@ -1,37 +0,0 @@ -import copy -from itertools import product - -import numpy as np - -from megengine import tensor -from megengine.module import ConvBn2d -from megengine.quantization.quantize import disable_fake_quant, quantize_qat -from megengine.test import assertTensorClose - - -def test_convbn2d(): - in_channels = 32 - out_channels = 64 - kernel_size = 3 - for groups, bias in product([1, 4], [True, False]): - module = ConvBn2d( - in_channels, out_channels, kernel_size, groups=groups, bias=bias - ) - module.train() - qat_module = quantize_qat(module, inplace=False) - disable_fake_quant(qat_module) - inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) - normal_outputs = module(inputs) - qat_outputs = qat_module(inputs) - assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) - assertTensorClose( - module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 - ) - assertTensorClose( - module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 - ) - module.eval() - normal_outputs = module(inputs) - qat_module.eval() - qat_outputs = qat_module(inputs) - assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) diff --git a/python_module/test/unit/module/test_qat.py b/python_module/test/unit/module/test_qat.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6c5a86b936612a947a121c123656db7ac9a628 --- /dev/null +++ b/python_module/test/unit/module/test_qat.py @@ -0,0 +1,85 @@ +from itertools import product + +import numpy as np + +from megengine import tensor +from megengine.module import ( + Conv2d, + ConvBn2d, + ConvRelu2d, + DequantStub, + Module, + QuantStub, +) +from megengine.quantization.quantize import disable_fake_quant, quantize_qat +from megengine.test import assertTensorClose + + +def test_qat_convbn2d(): + in_channels = 32 + out_channels = 64 + kernel_size = 3 + for groups, bias in product([1, 4], [True, False]): + module = ConvBn2d( + in_channels, out_channels, kernel_size, groups=groups, bias=bias + ) + module.train() + qat_module = quantize_qat(module, inplace=False) + disable_fake_quant(qat_module) + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + normal_outputs = module(inputs) + qat_outputs = qat_module(inputs) + assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) + assertTensorClose( + module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 + ) + assertTensorClose( + module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 + ) + module.eval() + normal_outputs = module(inputs) + qat_module.eval() + qat_outputs = qat_module(inputs) + assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) + + +def test_qat_conv(): + + in_channels = 32 + out_channels = 64 + kernel_size = 3 + + class TestNet(Module): + def __init__(self, groups, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.conv = Conv2d( + in_channels, out_channels, kernel_size, groups=groups, bias=bias + ) + self.conv_relu = ConvRelu2d( + out_channels, in_channels, kernel_size, groups=groups, bias=bias + ) + + def forward(self, inp): + out = self.quant(inp) + out = self.conv(out) + out = self.conv_relu(out) + out = self.dequant(out) + return out + + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + for groups, bias in product([1, 4], [True, False]): + net = TestNet(groups, bias) + net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) + assertTensorClose(normal_outputs, qat_outputs) + + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) + assertTensorClose(normal_outputs, qat_outputs)