From caf1fac2512d87a30e0466946be06d43e58f1e91 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 25 May 2020 19:17:05 +0800 Subject: [PATCH] refactor(mge/quantization): split `QATModule` and refactor convert api GitOrigin-RevId: 80cfb12d10590bbc88fd98370f5e3cf5d196d586 --- python_module/megengine/functional/nn.py | 8 +- python_module/megengine/module/__init__.py | 2 +- python_module/megengine/module/concat.py | 13 +- .../megengine/module/conv_bn_relu.py | 170 +-------------- python_module/megengine/module/elemwise.py | 13 +- python_module/megengine/module/linear.py | 15 +- python_module/megengine/module/module.py | 96 --------- .../megengine/module/qat/__init__.py | 13 ++ python_module/megengine/module/qat/concat.py | 30 +++ .../megengine/module/qat/conv_bn_relu.py | 193 ++++++++++++++++++ .../megengine/module/qat/elemwise.py | 29 +++ python_module/megengine/module/qat/linear.py | 37 ++++ python_module/megengine/module/qat/module.py | 96 +++++++++ .../megengine/module/qat/quant_dequant.py | 45 ++++ .../megengine/module/quant_dequant.py | 20 +- .../megengine/module/quantized/__init__.py | 1 + .../megengine/module/quantized/concat.py | 27 +-- .../module/quantized/conv_bn_relu.py | 68 +++--- .../megengine/module/quantized/elemwise.py | 29 +-- .../megengine/module/quantized/linear.py | 41 ++-- .../megengine/module/quantized/module.py | 31 +++ .../module/quantized/quant_dequant.py | 52 +++-- .../megengine/quantization/__init__.py | 9 - .../megengine/quantization/qconfig.py | 10 +- .../megengine/quantization/quantize.py | 107 +++++++--- .../test/unit/module/test_conv_bn_relu.py | 22 +- .../test/unit/quantization/quantize.py | 38 ++++ 27 files changed, 735 insertions(+), 480 deletions(-) create mode 100644 python_module/megengine/module/qat/__init__.py create mode 100644 python_module/megengine/module/qat/concat.py create mode 100644 python_module/megengine/module/qat/conv_bn_relu.py create mode 100644 python_module/megengine/module/qat/elemwise.py create mode 100644 python_module/megengine/module/qat/linear.py create mode 100644 python_module/megengine/module/qat/module.py create mode 100644 python_module/megengine/module/qat/quant_dequant.py create mode 100644 python_module/megengine/module/quantized/module.py create mode 100644 python_module/test/unit/quantization/quantize.py diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index 51852253e..eea340239 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: """Applies a linear transformation to the input. - Refer to :class:`~.Linear` for more information. + Refer to :class:`~.module.linear.Linear` for more information. :param inp: the input tensor with shape `(N, in_features)`. - :param weight: the weight with shape `(out_features, in_features)`. + :param weight: the weight with shape `(out_features, in_features)`. :param bias: the bias with shape `(out_features,)`. Default: ``None`` """ @@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor: r""" Performs the elementwise function: - + .. math:: - + \mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta. For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`. 
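This patch replaces the old single-class design, where ``QATModule`` switched behaviour through ``set_qat_mode``/``forward_qat``, with three parallel module families: float ``Module``, ``qat.QATModule`` and ``quantized.QuantizedModule``, converted explicitly by ``quantize_qat`` and ``quantize``. A minimal sketch of the intended workflow under the APIs introduced below; the layer sizes and input shape are illustrative only:

import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization.quantize import quantize, quantize_qat

net = ConvBn2d(3, 8, 3)                     # plain float module
qat_net = quantize_qat(net, inplace=False)  # swap in qat.* counterparts and attach the default qconfig
qat_net.train()
_ = qat_net(tensor(np.random.randn(4, 3, 32, 32).astype("float32")))  # observers collect statistics here
qat_net.eval()
q_net = quantize(qat_net)                   # swap in quantized.* counterparts, inference only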
diff --git a/python_module/megengine/module/__init__.py b/python_module/megengine/module/__init__.py index ae6a11940..0391f29d1 100644 --- a/python_module/megengine/module/__init__.py +++ b/python_module/megengine/module/__init__.py @@ -16,7 +16,7 @@ from .elemwise import Elemwise from .embedding import Embedding from .identity import Identity from .linear import Linear -from .module import Module, QATModule +from .module import Module from .parampack import ParamPack from .pooling import AvgPool2d, MaxPool2d from .quant_dequant import DequantStub, QuantStub diff --git a/python_module/megengine/module/concat.py b/python_module/megengine/module/concat.py index b62086d0d..453f951b1 100644 --- a/python_module/megengine/module/concat.py +++ b/python_module/megengine/module/concat.py @@ -9,19 +9,14 @@ from typing import Iterable from .. import functional as F from ..core.tensor import Tensor -from .module import QATModule +from .module import Module -class Concat(QATModule): +class Concat(Module): r""" - A :class:`~.QATModule` to do functional concat, should replace concat with this module, - supporting ``qat`` mode and ``quantized`` mode. + A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule` + version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`. """ def forward(self, inps: Iterable[Tensor], axis: int = 0): return F.concat(inps, axis) - - def forward_qat(self, inps: Iterable[Tensor], axis: int = 0): - return self.apply_fakequant_with_observer( - self.forward(inps, axis), self.act_fake_quant, self.act_observer - ) diff --git a/python_module/megengine/module/conv_bn_relu.py b/python_module/megengine/module/conv_bn_relu.py index af088feaf..bb3a85773 100644 --- a/python_module/megengine/module/conv_bn_relu.py +++ b/python_module/megengine/module/conv_bn_relu.py @@ -7,14 +7,13 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
from typing import Tuple, Union -from ..core import ones, zeros -from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad +from ..functional import relu from .batchnorm import BatchNorm2d from .conv import Conv2d -from .module import QATModule +from .module import Module -class _ConvBn2d(QATModule): +class _ConvBnActivation2d(Module): def __init__( self, in_channels: int, @@ -47,171 +46,24 @@ class _ConvBn2d(QATModule): ) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) - def get_batch_mean_var(self, inp): - def _sum_channel(inp, axis=0, keepdims=True): - if isinstance(axis, int): - out = sum(inp, axis=axis, keepdims=keepdims) - elif isinstance(axis, tuple): - for idx, elem in enumerate(axis): - out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) - return out - sum1 = _sum_channel(inp, (0, 2, 3)) - sum2 = _sum_channel(inp ** 2, (0, 2, 3)) - reduce_size = inp.shapeof().prod() / inp.shapeof(1) - batch_mean = sum1 / reduce_size - batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size - return batch_mean, batch_var - - def fold_weight_bias(self, bn_mean, bn_var): - # get fold bn conv param - # bn_istd = 1 / bn_std - # w_fold = gamma / bn_std * W - # b_fold = gamma * (b - bn_mean) / bn_std + beta - gamma = self.bn.weight - if gamma is None: - gamma = ones((self.bn.num_features), dtype="float32") - gamma = gamma.reshape(1, -1, 1, 1) - beta = self.bn.bias - if beta is None: - beta = zeros((self.bn.num_features), dtype="float32") - beta = beta.reshape(1, -1, 1, 1) - - if bn_mean is None: - bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") - if bn_var is None: - bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") - - conv_bias = self.conv.bias - if conv_bias is None: - conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") - - bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) - # bn_istd = 1 / bn_std - # w_fold = gamma / bn_std * W - scale_factor = gamma * bn_istd - if self.conv.groups == 1: - w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) - else: - w_fold = self.conv.weight * scale_factor.reshape( - self.conv.groups, -1, 1, 1, 1 - ) - - # b_fold = gamma * (b - bn_mean) / bn_std + beta - b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd - return w_fold, b_fold - - def update_running_mean_and_running_var( - self, bn_mean, bn_var, num_elements_per_channel - ): - # update running mean and running var. 
no grad, use unbiased bn var - bn_mean = zero_grad(bn_mean) - bn_var = ( - zero_grad(bn_var) - * num_elements_per_channel - / (num_elements_per_channel - 1) - ) - exponential_average_factor = 1 - self.bn.momentum - add_update( - self.bn.running_mean, - delta=bn_mean, - alpha=1 - exponential_average_factor, - beta=exponential_average_factor, - ) - add_update( - self.bn.running_var, - delta=bn_var, - alpha=1 - exponential_average_factor, - beta=exponential_average_factor, - ) - - def calc_conv_bn_qat(self, inp, approx=True): - if self.training and not approx: - conv = self.conv(inp) - bn_mean, bn_var = self.get_batch_mean_var(conv) - num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) - self.update_running_mean_and_running_var( - bn_mean, bn_var, num_elements_per_channel - ) - else: - bn_mean, bn_var = self.bn.running_mean, self.bn.running_var - - # get gamma and beta in BatchNorm - gamma = self.bn.weight - if gamma is None: - gamma = ones((self.bn.num_features), dtype="float32") - gamma = gamma.reshape(1, -1, 1, 1) - beta = self.bn.bias - if beta is None: - beta = zeros((self.bn.num_features), dtype="float32") - beta = beta.reshape(1, -1, 1, 1) - # conv_bias - conv_bias = self.conv.bias - if conv_bias is None: - conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") - - bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) - # bn_istd = 1 / bn_std - # w_fold = gamma / bn_std * W - scale_factor = gamma * bn_istd - if self.conv.groups == 1: - w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) - else: - w_fold = self.conv.weight * scale_factor.reshape( - self.conv.groups, -1, 1, 1, 1 - ) - b_fold = None - if not (self.training and approx): - # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta - b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd - - w_qat = self.apply_fakequant_with_observer( - w_fold, self.weight_fake_quant, self.weight_observer - ) - conv = self.conv.calc_conv(inp, w_qat, b_fold) - if not (self.training and approx): - return conv - - # rescale conv to get original conv output - orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) - if self.conv.bias is not None: - orig_conv = orig_conv + self.conv.bias - # calculate batch norm - bn_mean, bn_var = self.get_batch_mean_var(orig_conv) - bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) - conv = gamma * bn_istd * (orig_conv - bn_mean) + beta - num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) - self.update_running_mean_and_running_var( - bn_mean, bn_var, num_elements_per_channel - ) - return conv - - -class ConvBn2d(_ConvBn2d): +class ConvBn2d(_ConvBnActivation2d): r""" - A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode - and ``normal`` mode. + A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced + with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using + :func:`~.quantize.quantize_qat`. """ - def forward_qat(self, inp): - return self.apply_fakequant_with_observer( - self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer - ) - def forward(self, inp): return self.bn(self.conv(inp)) -class ConvBnRelu2d(_ConvBn2d): +class ConvBnRelu2d(_ConvBnActivation2d): r""" - A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat`` - mode and ``normal`` mode. + A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced + with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using + :func:`~.quantize.quantize_qat`. 
""" - def forward_qat(self, inp): - return self.apply_fakequant_with_observer( - relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer - ) - def forward(self, inp): return relu(self.bn(self.conv(inp))) diff --git a/python_module/megengine/module/elemwise.py b/python_module/megengine/module/elemwise.py index 9ca469c54..d1947b5e0 100644 --- a/python_module/megengine/module/elemwise.py +++ b/python_module/megengine/module/elemwise.py @@ -8,7 +8,7 @@ from .. import _internal as mgb from ..core import Tensor, wrap_io_tensor from ..core.graph import _use_default_if_none -from .module import QATModule +from .module import Module @wrap_io_tensor @@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor: return mgb.opr.elemwise(*inputs, mode=mode, **kwargs) -class Elemwise(QATModule): +class Elemwise(Module): r""" - A :class:`~.QATModule` to do elemwise operator, should functional operator with this module, - supporting ``qat`` mode and ``normal`` mode. + A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule` + version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`. :param method: the elemwise method, support the following string. It will do the normal elemwise operator for float. @@ -88,8 +88,3 @@ class Elemwise(QATModule): def forward(self, *inps): return _elemwise_func(self.method, *inps) - - def forward_qat(self, *inps): - return self.apply_fakequant_with_observer( - self.forward(*inps), self.act_fake_quant, self.act_observer, - ) diff --git a/python_module/megengine/module/linear.py b/python_module/megengine/module/linear.py index 30f1ec3d7..30f8ea821 100644 --- a/python_module/megengine/module/linear.py +++ b/python_module/megengine/module/linear.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -11,10 +10,10 @@ import numpy as np from .. import functional as F from ..core import Parameter from . import init -from .module import QATModule +from .module import Module -class Linear(QATModule): +class Linear(Module): r"""Applies a linear transformation to the input. For instance, if input is x, then output y is: @@ -60,13 +59,3 @@ class Linear(QATModule): def forward(self, x): return self._calc_linear(x, self.weight, self.bias) - - def forward_qat(self, x): - w_qat = self.apply_fakequant_with_observer( - self.weight, self.weight_fake_quant, self.weight_observer - ) - return self.apply_fakequant_with_observer( - self._calc_linear(x, w_qat, self.bias), - self.act_fake_quant, - self.act_observer, - ) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index fdea96b51..183e3e42b 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -7,7 +7,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from abc import ABCMeta, abstractmethod from collections import OrderedDict -from enum import Enum from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union import numpy as np @@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta): loaded.append(k) return set(loaded), set(skipped) - - -class QATModule(Module): - r""" - Base class of quantization related Module. 
Add extra forward methods - :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for - ``qat``(quantization aware training) mode and ``quantized`` mode respectively. - - Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode, - and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode, - which is irreversible. - - If you want to recursively switch mode for all QATModule in network, use - functions in :mod:`~.quantization.quantize`. - """ - - class QATMode(Enum): - DISABLED = 1 - QAT = 2 - CALIBRATION = 3 - - def __init__(self): - from ..quantization import ( - QConfig, - FakeQuantize, - Observer, - ) # pylint: disable=all - - super().__init__() - - self.quantizing = self.QATMode.DISABLED - self.scale = None - - self.weight_observer = None # type: Observer - self.act_observer = None # type: Observer - - self.weight_fake_quant = None # type: FakeQuantize - self.act_fake_quant = None # type: FakeQuantize - - def set_qconfig(self, qconfig: "QConfig"): - self.weight_observer = qconfig.weight_observer() - self.act_observer = qconfig.act_observer() - - self.weight_fake_quant = ( - None - if qconfig.fake_quant is None - else qconfig.fake_quant(self.weight_observer.dtype) - ) - self.act_fake_quant = ( - None - if qconfig.fake_quant is None - else qconfig.fake_quant(self.act_observer.dtype) - ) - - def apply_observer(self, target: Tensor, obs: "Observer"): - return obs(target) - - def apply_fakequant_with_observer( - self, target: Tensor, fq: "FakeQuantize", obs: "Observer" - ): - oup = self.apply_observer(target, obs) - if fq is not None: - q_dict = obs.get_qparams() - oup = fq(oup, q_dict) - return oup - - def set_qat_mode(self, mode: QATMode): - r""" - Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``, - ``QAT``,``CALIBRATION``. - """ - if not isinstance(mode, self.QATMode): - raise TypeError("mode must be QATMode Enum type") - self.quantizing = mode - - def to_quantized(self): - r""" - Return a new :class:`~.Module` with quantized parameters of ``self`` - according to scale and zero_point in ``self.xxx_observer``. - """ - raise NotImplementedError( - "Use megengine.quantization.quantize to register the method." - ) - - @abstractmethod - def forward_qat(self, *args, **kwargs): - r""" - Forward method for ``qat`` mode. - """ - - def __call__(self, *args, **kwargs): - if self.quantizing == self.QATMode.DISABLED: - return self.forward(*args, **kwargs) - else: - return self.forward_qat(*args, **kwargs) diff --git a/python_module/megengine/module/qat/__init__.py b/python_module/megengine/module/qat/__init__.py new file mode 100644 index 000000000..26c4b6eeb --- /dev/null +++ b/python_module/megengine/module/qat/__init__.py @@ -0,0 +1,13 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+from .concat import Concat +from .conv_bn_relu import ConvBn2d, ConvBnRelu2d +from .elemwise import Elemwise +from .linear import Linear +from .module import QATModule +from .quant_dequant import DequantStub, QuantStub diff --git a/python_module/megengine/module/qat/concat.py b/python_module/megengine/module/qat/concat.py new file mode 100644 index 000000000..893b1ad05 --- /dev/null +++ b/python_module/megengine/module/qat/concat.py @@ -0,0 +1,30 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from typing import Iterable + +from ...core.tensor import Tensor +from .. import concat as Float +from .module import QATModule + + +class Concat(Float.Concat, QATModule): + r""" + A :class:`~.QATModule` to do functional concat with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + """ + + def forward(self, inps: Iterable[Tensor], axis: int = 0): + return self.apply_quant_activation(super().forward(inps, axis)) + + @classmethod + def from_float_module(cls, float_module): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + return cls() diff --git a/python_module/megengine/module/qat/conv_bn_relu.py b/python_module/megengine/module/qat/conv_bn_relu.py new file mode 100644 index 000000000..4a8be9e17 --- /dev/null +++ b/python_module/megengine/module/qat/conv_bn_relu.py @@ -0,0 +1,193 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ...core import ones, zeros +from ...functional import add_update, relu, sqrt, sum, zero_grad +from .. 
import conv_bn_relu as Float +from .module import QATModule + + +class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): + def get_batch_mean_var(self, inp): + def _sum_channel(inp, axis=0, keepdims=True): + if isinstance(axis, int): + out = sum(inp, axis=axis, keepdims=keepdims) + elif isinstance(axis, tuple): + for idx, elem in enumerate(axis): + out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) + return out + + sum1 = _sum_channel(inp, (0, 2, 3)) + sum2 = _sum_channel(inp ** 2, (0, 2, 3)) + reduce_size = inp.shapeof().prod() / inp.shapeof(1) + batch_mean = sum1 / reduce_size + batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size + return batch_mean, batch_var + + def fold_weight_bias(self, bn_mean, bn_var): + # get fold bn conv param + # bn_istd = 1 / bn_std + # w_fold = gamma / bn_std * W + # b_fold = gamma * (b - bn_mean) / bn_std + beta + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features), dtype="float32") + beta = beta.reshape(1, -1, 1, 1) + + if bn_mean is None: + bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") + if bn_var is None: + bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") + + conv_bias = self.conv.bias + if conv_bias is None: + conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + # bn_istd = 1 / bn_std + # w_fold = gamma / bn_std * W + scale_factor = gamma * bn_istd + if self.conv.groups == 1: + w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) + else: + w_fold = self.conv.weight * scale_factor.reshape( + self.conv.groups, -1, 1, 1, 1 + ) + + # b_fold = gamma * (b - bn_mean) / bn_std + beta + b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd + return w_fold, b_fold + + def update_running_mean_and_running_var( + self, bn_mean, bn_var, num_elements_per_channel + ): + # update running mean and running var. 
no grad, use unbiased bn var + bn_mean = zero_grad(bn_mean) + bn_var = ( + zero_grad(bn_var) + * num_elements_per_channel + / (num_elements_per_channel - 1) + ) + exponential_average_factor = 1 - self.bn.momentum + add_update( + self.bn.running_mean, + delta=bn_mean, + alpha=1 - exponential_average_factor, + beta=exponential_average_factor, + ) + add_update( + self.bn.running_var, + delta=bn_var, + alpha=1 - exponential_average_factor, + beta=exponential_average_factor, + ) + + def calc_conv_bn_qat(self, inp, approx=True): + if self.training and not approx: + conv = self.conv(inp) + bn_mean, bn_var = self.get_batch_mean_var(conv) + num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) + else: + bn_mean, bn_var = self.bn.running_mean, self.bn.running_var + + # get gamma and beta in BatchNorm + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features), dtype="float32") + beta = beta.reshape(1, -1, 1, 1) + # conv_bias + conv_bias = self.conv.bias + if conv_bias is None: + conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + # bn_istd = 1 / bn_std + # w_fold = gamma / bn_std * W + scale_factor = gamma * bn_istd + if self.conv.groups == 1: + w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) + else: + w_fold = self.conv.weight * scale_factor.reshape( + self.conv.groups, -1, 1, 1, 1 + ) + b_fold = None + if not (self.training and approx): + # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta + b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd + + w_qat = self.apply_quant_weight(w_fold) + conv = self.conv.calc_conv(inp, w_qat, b_fold) + if not (self.training and approx): + return conv + + # rescale conv to get original conv output + orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) + if self.conv.bias is not None: + orig_conv = orig_conv + self.conv.bias + # calculate batch norm + bn_mean, bn_var = self.get_batch_mean_var(orig_conv) + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + conv = gamma * bn_istd * (orig_conv - bn_mean) + beta + num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) + return conv + + @classmethod + def from_float_module(cls, float_module: Float._ConvBnActivation2d): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + qat_module = cls( + float_module.conv.in_channels, + float_module.conv.out_channels, + float_module.conv.kernel_size, + float_module.conv.stride, + float_module.conv.padding, + float_module.conv.dilation, + float_module.conv.groups, + bool(float_module.conv.bias), + float_module.conv.conv_mode.name, + float_module.conv.compute_mode.name, + ) + qat_module.conv.weight = float_module.conv.weight + qat_module.conv.bias = float_module.conv.bias + qat_module.bn = float_module.bn + return qat_module + + +class ConvBn2d(_ConvBnActivation2d): + r""" + A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. 
+ """ + + def forward(self, inp): + return self.apply_quant_activation(self.calc_conv_bn_qat(inp)) + + +class ConvBnRelu2d(_ConvBnActivation2d): + r""" + A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp))) diff --git a/python_module/megengine/module/qat/elemwise.py b/python_module/megengine/module/qat/elemwise.py new file mode 100644 index 000000000..37e03e817 --- /dev/null +++ b/python_module/megengine/module/qat/elemwise.py @@ -0,0 +1,29 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .. import elemwise as Float +from .module import QATModule + + +class Elemwise(Float.Elemwise, QATModule): + r""" + A :class:`~.QATModule` to do elemwise operator with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + + :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. + """ + + def forward(self, *inps): + return self.apply_quant_activation(super().forward(*inps)) + + @classmethod + def from_float_module(cls, float_module: Float.Elemwise): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + return cls(float_module.method.name) diff --git a/python_module/megengine/module/qat/linear.py b/python_module/megengine/module/qat/linear.py new file mode 100644 index 000000000..d8174624f --- /dev/null +++ b/python_module/megengine/module/qat/linear.py @@ -0,0 +1,37 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .. import linear as Float +from .module import QATModule + + +class Linear(Float.Linear, QATModule): + r""" + A :class:`~.QATModule` version of :class:`~.module.linear.Linear`. + Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. + + :param in_features: size of each input sample. + :param out_features: size of each output sample. + :param bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + """ + + def forward(self, x): + w_qat = self.apply_quant_weight(self.weight) + return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),) + + @classmethod + def from_float_module(cls, float_module: Float.Linear): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. 
+ """ + qmod = cls(float_module.in_features, float_module.out_features) + qmod.weight = float_module.weight + qmod.bias = float_module.bias + return qmod diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py new file mode 100644 index 000000000..7dc47996d --- /dev/null +++ b/python_module/megengine/module/qat/module.py @@ -0,0 +1,96 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from abc import abstractmethod + +from ...core import Tensor +from ...quantization import FakeQuantize, Observer, QConfig +from ..module import Module + + +class QATModule(Module): + r""" + Base class of quantized-float related Module, basically for QAT and Calibration. + + Use :meth:`~.QATModule.from_float_module` to generate a instance from float :class:`~.Module`. + Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically. + + Can also be converted to :class:`~.QuantizedModule` for deployment using + :func:`~.quantize.quantize` further. + """ + + def __init__(self): + super().__init__() + + self.scale = None + + self.weight_observer = None # type: Observer + self.act_observer = None # type: Observer + + self.weight_fake_quant = None # type: FakeQuantize + self.act_fake_quant = None # type: FakeQuantize + + def set_qconfig(self, qconfig: QConfig): + r""" + Set quantization related configs with ``qconfig``, including + observer and fake_quant for weight and activation. + """ + self.weight_observer = qconfig.weight_observer() + self.act_observer = qconfig.act_observer() + + if qconfig.fake_quant is None: + self.weight_fake_quant = None + self.act_fake_quant = None + else: + self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) + self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) + + def _apply_fakequant_with_observer( + self, target: Tensor, fake_quant: FakeQuantize, observer: Observer + ): + oup = observer(target) + if fake_quant is None: + return oup + else: + q_dict = observer.get_qparams() + return fake_quant(oup, q_dict) + + def apply_quant_weight(self, target: Tensor): + r""" + Apply weight's observer and fake_quant from ``qconfig`` on ``target``. + """ + return self._apply_fakequant_with_observer( + target, self.weight_fake_quant, self.weight_observer + ) + + def apply_quant_activation(self, target: Tensor): + r""" + Apply weight's observer and fake_quant from ``qconfig`` on ``target``. + """ + return self._apply_fakequant_with_observer( + target, self.act_fake_quant, self.act_observer + ) + + def get_weight_dtype(self): + r""" + Get weight's quantization dtype as the method from ``qconfig``. + """ + return self.weight_observer.get_dtype() + + def get_activation_dtype(self): + r""" + Get activation's quantization dtype as the method from ``qconfig``. + """ + return self.act_observer.get_dtype() + + @classmethod + @abstractmethod + def from_float_module(cls, float_module: Module): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. 
+ """ diff --git a/python_module/megengine/module/qat/quant_dequant.py b/python_module/megengine/module/qat/quant_dequant.py new file mode 100644 index 000000000..84ebdf92e --- /dev/null +++ b/python_module/megengine/module/qat/quant_dequant.py @@ -0,0 +1,45 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .. import quant_dequant as Float +from .module import QATModule + + +class QuantStub(Float.QuantStub, QATModule): + r""" + A helper QATModule simply return input, but will quantize + input after converted to :class:`~.QuantizedModule`. + """ + + def forward(self, inp): + return self.apply_quant_activation(inp) + + @classmethod + def from_float_module(cls, float_module: Float.QuantStub): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + return cls() + + +class DequantStub(Float.DequantStub, QATModule): + r""" + A helper QATModule simply return input, but will de-quantize + input after converted to :class:`~.QuantizedModule`. + """ + + def forward(self, inp): + return inp + + @classmethod + def from_float_module(cls, float_module: Float.DequantStub): + r""" + Return a :class:`~.QATModule` instance converted from + a float :class:`~.Module` instance. + """ + return cls() diff --git a/python_module/megengine/module/quant_dequant.py b/python_module/megengine/module/quant_dequant.py index ed20e3c0e..aaf2b0cc3 100644 --- a/python_module/megengine/module/quant_dequant.py +++ b/python_module/megengine/module/quant_dequant.py @@ -5,30 +5,24 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .module import QATModule +from .module import Module -class QuantStub(QATModule): +class QuantStub(Module): r""" - A helper QATModule doing quantize operation on input. + A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule` + version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`. """ def forward(self, inp): return inp - def forward_qat(self, inp): - return self.apply_fakequant_with_observer( - inp, self.act_fake_quant, self.act_observer - ) - -class DequantStub(QATModule): +class DequantStub(Module): r""" - A helper QATModule doing de-quantize operation on input. + A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule` + version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`. 
""" def forward(self, inp): return inp - - def forward_qat(self, inp): - return inp diff --git a/python_module/megengine/module/quantized/__init__.py b/python_module/megengine/module/quantized/__init__.py index 040a3b14d..b79977214 100644 --- a/python_module/megengine/module/quantized/__init__.py +++ b/python_module/megengine/module/quantized/__init__.py @@ -9,4 +9,5 @@ from .concat import Concat from .conv_bn_relu import ConvBn2d, ConvBnRelu2d from .elemwise import Elemwise from .linear import Linear +from .module import QuantizedModule from .quant_dequant import DequantStub, QuantStub diff --git a/python_module/megengine/module/quantized/concat.py b/python_module/megengine/module/quantized/concat.py index f3f266a1d..f9ef05d9c 100644 --- a/python_module/megengine/module/quantized/concat.py +++ b/python_module/megengine/module/quantized/concat.py @@ -7,17 +7,15 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Iterable -from ... import _internal as mgb from ... import functional as F -from ... import module as Float from ...core.tensor import Tensor -from ...quantization.utils import register_method_to_class -from ..module import Module +from ..qat import concat as QAT +from .module import QuantizedModule -class Concat(Module): +class Concat(QuantizedModule): r""" - A :class:`~.Module` to do quantized concat, inference only. + A :class:`~.QuantizedModule` to do quantized concat, inference only. """ def __init__(self, dtype=None): @@ -25,16 +23,13 @@ class Concat(Module): self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): - if self.training: - raise ValueError("quantized module only support inference.") new_inps = (x.astype(self.output_dtype) for x in inps) return F.concat(new_inps, axis) - -@register_method_to_class(Float.Concat) -def to_quantized(float_module): - r""" - Replace :class:`~.module.QATModule`'s ``to_quantized`` method. - implemented here to avoid circular import. - """ - return Concat(float_module.act_observer.get_dtype()) + @classmethod + def from_qat_module(cls, qat_module: QAT.Concat): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + return cls(qat_module.get_activation_dtype()) diff --git a/python_module/megengine/module/quantized/conv_bn_relu.py b/python_module/megengine/module/quantized/conv_bn_relu.py index 18eddaa8f..6b72921e0 100644 --- a/python_module/megengine/module/quantized/conv_bn_relu.py +++ b/python_module/megengine/module/quantized/conv_bn_relu.py @@ -5,7 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from functools import partial from typing import Tuple, Union import megengine._internal as mgb @@ -13,11 +12,11 @@ import megengine._internal as mgb from ... import module as Float from ...core import Parameter from ...functional import conv_bias_activation -from ...module import Conv2d -from ...quantization.utils import register_method_to_class +from ..qat import conv_bn_relu as QAT +from .module import QuantizedModule -class _ConvBnActivation2d(Conv2d): +class _ConvBnActivation2d(Float.Conv2d, QuantizedModule): r"""Applies a 2D convolution over an quantized input tensor, inference only. 
The parameters are the same as :class:`~.Conv2d`. @@ -68,44 +67,41 @@ nonlinear_mode=nonlinear_mode, ) + @classmethod + def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + output_dtype = qat_module.get_activation_dtype() + qconv = cls( + qat_module.conv.in_channels, + qat_module.conv.out_channels, + qat_module.conv.kernel_size, + qat_module.conv.stride, + qat_module.conv.padding, + qat_module.conv.dilation, + qat_module.conv.groups, + dtype=output_dtype, + ) + w_fold, b_fold = qat_module.fold_weight_bias( + qat_module.bn.running_mean, qat_module.bn.running_var + ) + weight = w_fold.astype(qat_module.get_weight_dtype()) + qconv.weight = Parameter(weight.numpy()) + qconv.bias = Parameter(b_fold.numpy()) + return qconv + class ConvBn2d(_ConvBnActivation2d): + r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`.""" + def forward(self, inp): - if self.training: - raise ValueError("quantized module only support inference.") return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") class ConvBnRelu2d(_ConvBnActivation2d): + r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`.""" + def forward(self, inp): - if self.training: - raise ValueError("quantized module only support inference.") return self.calc_conv_quantized(inp, nonlinear_mode="RELU") - - -def to_quantized(quantized_class, float_module): - output_dtype = float_module.act_observer.get_dtype() - qconv = quantized_class( - float_module.conv.in_channels, - float_module.conv.out_channels, - float_module.conv.kernel_size, - float_module.conv.stride, - float_module.conv.padding, - float_module.conv.dilation, - float_module.conv.groups, - dtype=output_dtype, - ) - w_fold, b_fold = float_module.fold_weight_bias( - float_module.bn.running_mean, float_module.bn.running_var - ) - weight = w_fold.astype(float_module.weight_observer.get_dtype()) - qconv.weight = Parameter(weight.numpy()) - qconv.bias = Parameter(b_fold.numpy()) - - return qconv - - -# replace :class:`~.module.QATModule`'s ``to_quantized`` method. -# implemented here to avoid circular import. -register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d)) -register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d)) diff --git a/python_module/megengine/module/quantized/elemwise.py b/python_module/megengine/module/quantized/elemwise.py index 47f30e47a..db04ed654 100644 --- a/python_module/megengine/module/quantized/elemwise.py +++ b/python_module/megengine/module/quantized/elemwise.py @@ -6,11 +6,10 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ... import _internal as mgb -from ... import module as Float from ...core import Tensor, wrap_io_tensor from ...core.graph import _use_default_if_none -from ...quantization.utils import register_method_to_class -from ..module import Module +from ..qat import elemwise as QAT +from .module import QuantizedModule @wrap_io_tensor def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor: @@ -24,13 +23,8 @@ return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs) -class Elemwise(Module): - r""" - quantized module for elemwise operator, inference only. - - :param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
- it will do quantized operator with specified output quantized dtype. - """ +class Elemwise(QuantizedModule): + r"""quantized version of :class:`~.qat.elemwise.Elemwise`.""" _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode @@ -44,11 +38,10 @@ class Elemwise(Module): raise ValueError("quantized module only support inference.") return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype) - -@register_method_to_class(Float.Elemwise) -def to_quantized(float_module): - r""" - Replace :class:`~.module.QATModule`'s ``to_quantized`` method. - implemented here to avoid circular import. - """ - return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) + @classmethod + def from_qat_module(cls, qat_module: QAT.Elemwise): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + return cls(qat_module.method.name, qat_module.get_activation_dtype()) diff --git a/python_module/megengine/module/quantized/linear.py b/python_module/megengine/module/quantized/linear.py index 243db7d7f..4c7989297 100644 --- a/python_module/megengine/module/quantized/linear.py +++ b/python_module/megengine/module/quantized/linear.py @@ -10,19 +10,13 @@ import numpy as np import megengine._internal as mgb from ... import functional as F -from ... import module as Float from ...core import Parameter -from ...quantization.utils import register_method_to_class -from ..module import Module +from ..qat import linear as QAT +from .module import QuantizedModule -class Linear(Module): - r"""Applies a quantized linear transformation to the input. The module - usually convert from QAT module by to_quantized method. - - :param dtype: output data type. - - """ +class Linear(QuantizedModule): + r"""quantized version of :class:`~.qat.linear.Linear`.""" def __init__( self, dtype: np.dtype = None, @@ -44,17 +38,16 @@ class Linear(Module): None if self.bias is None else self.bias.astype(bias_dtype), ).astype(self.output_dtype) - -@register_method_to_class(Float.Linear) -def to_quantized(float_module): - r""" - Replace :class:`~.module.QATModule`'s ``to_quantized`` method. - implemented here to avoid circular import. - """ - output_dtype = float_module.act_observer.get_dtype() - qmod = Linear(dtype=output_dtype,) - weight = float_module.weight.astype(float_module.weight_observer.get_dtype()) - qmod.weight = Parameter(weight.numpy()) - if float_module.bias is not None: - qmod.bias = Parameter(float_module.bias.numpy()) - return qmod + @classmethod + def from_qat_module(cls, qat_module: QAT.Linear): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + output_dtype = qat_module.get_activation_dtype() + qmod = cls(dtype=output_dtype) + weight = qat_module.weight.astype(qat_module.get_weight_dtype()) + qmod.weight = Parameter(weight.numpy()) + if qat_module.bias is not None: + qmod.bias = Parameter(qat_module.bias.numpy()) + return qmod diff --git a/python_module/megengine/module/quantized/module.py b/python_module/megengine/module/quantized/module.py new file mode 100644 index 000000000..4fccdbfa2 --- /dev/null +++ b/python_module/megengine/module/quantized/module.py @@ -0,0 +1,31 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from abc import abstractmethod + +from ..module import Module +from ..qat import QATModule + + +class QuantizedModule(Module): + r""" + Base class of quantized Module, which should be converted from a QATModule + and does not support training. + """ + + def __call__(self, *inputs, **kwargs): + if self.training: + raise ValueError("quantized module only support inference.") + return super().__call__(*inputs, **kwargs) + + @classmethod + @abstractmethod + def from_qat_module(cls, qat_module: QATModule): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ diff --git a/python_module/megengine/module/quantized/quant_dequant.py b/python_module/megengine/module/quantized/quant_dequant.py index 5a91b6fd1..0c245011f 100644 --- a/python_module/megengine/module/quantized/quant_dequant.py +++ b/python_module/megengine/module/quantized/quant_dequant.py @@ -5,15 +5,14 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ... import _internal as mgb -from ... import module as Float -from ...quantization.utils import register_method_to_class -from ..module import Module +from ..qat import quant_dequant as QAT +from .module import QuantizedModule -class QuantStub(Module): +class QuantStub(QuantizedModule): r""" - A helper quantize operation on input and inference only. + quantized version of :class:`~.qat.quant_dequant.QuantStub`, + which will convert the input to a quantized dtype. """ def __init__(self, dtype=None): @@ -21,35 +20,30 @@ self.output_dtype = dtype def forward(self, inp): - if self.training: - raise ValueError("quantized module only support inference.") return inp.astype(self.output_dtype) + @classmethod + def from_qat_module(cls, qat_module: QAT.QuantStub): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + return cls(qat_module.get_activation_dtype()) -class DequantStub(Module): + +class DequantStub(QuantizedModule): r""" - A helper de-quantize operation and inference only. + quantized version of :class:`~.qat.quant_dequant.DequantStub`, + which will restore the quantized input to float32 dtype. """ def forward(self, inp): - if self.training: - raise ValueError("quantized module only support inference.") return inp.astype("float32") - -@register_method_to_class(Float.QuantStub) -def to_quantized(float_module): - r""" - Replace :class:`~.module.QATModule`'s ``to_quantized`` method. - implemented here to avoid circular import. - """ - return QuantStub(float_module.act_observer.get_dtype()) - - -@register_method_to_class(Float.DequantStub) -def to_quantized(float_module): - r""" - Replace :class:`~.module.QATModule`'s ``to_quantized`` method. - implemented here to avoid circular import. - """ - return DequantStub() + @classmethod + def from_qat_module(cls, qat_module: QAT.DequantStub): + r""" + return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance.
+ """ + return cls() diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 1e99493ff..4a2dc96dc 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -13,12 +13,3 @@ from .qconfig import ( ema_fakequant_qconfig, min_max_fakequant_qconfig, ) -from .quantize import ( - disable_fake_quant, - disable_observer, - enable_fake_quant, - enable_observer, - quantize, - quantize_calibration, - quantize_qat, -) diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py index cfabdc58e..410cefe4e 100644 --- a/python_module/megengine/quantization/qconfig.py +++ b/python_module/megengine/quantization/qconfig.py @@ -15,16 +15,12 @@ from .observer import ( class QConfig: - """ + r""" A config class indicating how to do quantize toward :class:`~.QATModule`'s - ``activation`` and ``weight``. - - And ``fake_quant`` parameter to indicate - - See :meth:`~.QATModule.set_qconfig` for detail usage. + ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage. :param weight_observer: interface to instantiate an :class:`~.Observer` indicating -- how to collect scales and zero_point of wegiht. + how to collect scales and zero_point of wegiht. :param act_observer: similar to ``weight_observer`` but toward activation. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating how to do fake_quant calculation. can be invoked multi times to get different diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 3296f5762..c5558a2be 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -6,68 +6,125 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import deepcopy - -from ..module import Module, QATModule, Sequential, quantized +from typing import Dict, Tuple + +from .. import module as Float +from ..module import Module +from ..module import qat as QAT +from ..module import quantized as Quantized +from ..module.qat import QATModule +from ..module.quantized import QuantizedModule from .qconfig import QConfig, ema_fakequant_qconfig +def _get_quantable_module_names(): + def is_quantable(key: str): + value = getattr(Quantized, key) + return ( + isinstance(value, type) + and issubclass(value, QuantizedModule) + and value != QuantizedModule + ) + + # source should have all quantable modules' names + quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)] + return quantable_module_names + + +def _get_convert_dict() -> Tuple[ + Dict[Module, QATModule], Dict[QATModule, QuantizedModule] +]: + quantable_module_names = _get_quantable_module_names() + + quantable_modules = [getattr(Float, key) for key in quantable_module_names] + qat_modules = [getattr(QAT, key) for key in quantable_module_names] + quantized_modules = [getattr(Quantized, key) for key in quantable_module_names] + + float2qat_dict = dict(zip(quantable_modules, qat_modules)) + qat2quantized_dict = dict(zip(qat_modules, quantized_modules)) + return float2qat_dict, qat2quantized_dict + + +_float2qat_dict, _qat2quantized_dict = _get_convert_dict() + + def quantize(module: Module, inplace=True): r""" - Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`. 
diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 3296f5762..c5558a2be 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -6,68 +6,125 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import deepcopy - -from ..module import Module, QATModule, Sequential, quantized +from typing import Dict, Tuple + +from .. import module as Float +from ..module import Module +from ..module import qat as QAT +from ..module import quantized as Quantized +from ..module.qat import QATModule +from ..module.quantized import QuantizedModule from .qconfig import QConfig, ema_fakequant_qconfig +def _get_quantable_module_names(): + def is_quantable(key: str): + value = getattr(Quantized, key) + return ( + isinstance(value, type) + and issubclass(value, QuantizedModule) + and value != QuantizedModule + ) + + # source should have all quantable modules' names + quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)] + return quantable_module_names + + +def _get_convert_dict() -> Tuple[ + Dict[Module, QATModule], Dict[QATModule, QuantizedModule] +]: + quantable_module_names = _get_quantable_module_names() + + quantable_modules = [getattr(Float, key) for key in quantable_module_names] + qat_modules = [getattr(QAT, key) for key in quantable_module_names] + quantized_modules = [getattr(Quantized, key) for key in quantable_module_names] + + float2qat_dict = dict(zip(quantable_modules, qat_modules)) + qat2quantized_dict = dict(zip(qat_modules, quantized_modules)) + return float2qat_dict, qat2quantized_dict + + +_float2qat_dict, _qat2quantized_dict = _get_convert_dict() + + def quantize(module: Module, inplace=True): r""" - Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`. + Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` + through :meth:`~.Module.apply`. :param module: root module to do convert recursively. + :param inplace: whether to convert submodules in-place. """ if not inplace: module = deepcopy(module) - def is_qat_module(obj): - return isinstance(obj, QATModule) + qat_modules = tuple(_qat2quantized_dict.keys()) + + def is_qat(mod: Module): + return isinstance(mod, qat_modules) # no need to pass prefix and get pure key of parent Module. for key, submodule, parent in module._flatten( with_key=True, with_parent=True, predicate=is_qat ): - if isinstance(parent, Sequential): + new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) + if isinstance(parent, Float.Sequential): # cannot use setattr to be compatible with Sequential's ``__setitem__`` - parent[int(key.split(".")[-1])] = submodule.to_quantized() + parent[int(key.split(".")[-1])] = new_mod else: - setattr(parent, key.split(".")[-1], submodule.to_quantized()) + setattr(parent, key.split(".")[-1], new_mod) return module -def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): +def quantize_qat( + module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig +): r""" - Recursively convert `module` to `qat` mode through :meth:`~.Module.apply` - and set qconfig relatively. + Recursively convert float :class:`~.Module` to :class:`~.QATModule` + through :meth:`~.Module.apply` and set qconfig accordingly. :param module: root module to do convert recursively. - :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. - default is :any:`~.qconfig.ema_fakequant_qconfig`. + :param inplace: whether to convert submodules in-place. + :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. + default is ``ema_fakequant_qconfig``. """ - def fn(mod: Module): - if isinstance(mod, QATModule): - mod.set_qat_mode(QATModule.QATMode.QAT) - mod.set_qconfig(qconfig) + if not inplace: + module = deepcopy(module) - module.apply(fn) + quantable_modules = tuple(_float2qat_dict.keys()) + + def is_quantable(mod: Module): + return isinstance(mod, quantable_modules) + + # no need to pass prefix and get pure key of parent Module. + for key, submodule, parent in module._flatten( + with_key=True, with_parent=True, predicate=is_quantable + ): + new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) + if isinstance(parent, Float.Sequential): + # cannot use setattr to be compatible with Sequential's ``__setitem__`` + parent[int(key.split(".")[-1])] = new_mod + else: + setattr(parent, key.split(".")[-1], new_mod) + + propagate_qconfig(module, qconfig) + return module -def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): +def propagate_qconfig(module: QATModule, qconfig: QConfig): r""" - Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` - and set qconfig relatively. + Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. - :param module: root module to do convert recursively. + :param module: root module to traverse recursively. :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. - default is :any:`~.qconfig.ema_fakequant_qconfig`.
""" def fn(mod: Module): if isinstance(mod, QATModule): - mod.set_qat_mode(QATModule.QATMode.CALIBRATION) mod.set_qconfig(qconfig) module.apply(fn) diff --git a/python_module/test/unit/module/test_conv_bn_relu.py b/python_module/test/unit/module/test_conv_bn_relu.py index c713448ff..308adf4fc 100644 --- a/python_module/test/unit/module/test_conv_bn_relu.py +++ b/python_module/test/unit/module/test_conv_bn_relu.py @@ -5,8 +5,7 @@ import numpy as np from megengine import tensor from megengine.module import ConvBn2d -from megengine.quantization import quantize_qat -from megengine.quantization.quantize import disable_fake_quant +from megengine.quantization.quantize import disable_fake_quant, quantize_qat from megengine.test import assertTensorClose @@ -14,18 +13,17 @@ def test_convbn2d(): in_channels = 32 out_channels = 64 kernel_size = 3 - module = ConvBn2d(in_channels, out_channels, kernel_size) - quantize_qat(module) for groups, bias in product([1, 4], [True, False]): - inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + module = ConvBn2d( + in_channels, out_channels, kernel_size, groups=groups, bias=bias + ) module.train() - qat_module = copy.deepcopy(module) + qat_module = quantize_qat(module, inplace=False) disable_fake_quant(qat_module) - normal_outputs = module.forward(inputs) - qat_outputs = qat_module.forward_qat(inputs) + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + normal_outputs = module(inputs) + qat_outputs = qat_module(inputs) assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) - a = module.bn.running_mean.numpy() - b = qat_module.bn.running_mean.numpy() assertTensorClose( module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 ) @@ -33,7 +31,7 @@ def test_convbn2d(): module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 ) module.eval() - normal_outputs = module.forward(inputs) + normal_outputs = module(inputs) qat_module.eval() - qat_outputs = qat_module.forward_qat(inputs) + qat_outputs = qat_module(inputs) assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) diff --git a/python_module/test/unit/quantization/quantize.py b/python_module/test/unit/quantization/quantize.py new file mode 100644 index 000000000..36cb5279e --- /dev/null +++ b/python_module/test/unit/quantization/quantize.py @@ -0,0 +1,38 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+from megengine import module as Float +from megengine.module import qat as QAT +from megengine.quantization.quantize import _get_quantable_module_names + + +def test_get_quantable_module_names(): + # need to make sure names from Quantized and QAT are the same + def _get_qat_module_names(): + def is_qat(key: str): + value = getattr(QAT, key) + return ( + isinstance(value, type) + and issubclass(value, QAT.QATModule) + and value != QAT.QATModule + ) + + # source should have all quantable modules' names + quantable_module_names = [key for key in dir(QAT) if is_qat(key)] + return quantable_module_names + + qat_module_names = _get_qat_module_names() + quantized_module_names = _get_quantable_module_names() + assert set(qat_module_names) == set(quantized_module_names) + + for key in qat_module_names: + value = getattr(Float, key) + assert ( + isinstance(value, type) + and issubclass(value, Float.Module) + and value != Float.Module + ) -- GitLab
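The test above leans on the convention that makes the new convert API work: a quantable module must expose the same class name in the ``Float``, ``QAT`` and ``Quantized`` namespaces, which ``_get_quantable_module_names`` and ``_get_convert_dict`` key on. A small introspection sketch for anyone wiring up a new quantable module; illustrative only:

from megengine import module as Float
from megengine.module import qat as QAT
from megengine.module import quantized as Quantized
from megengine.quantization.quantize import _get_quantable_module_names

for name in _get_quantable_module_names():
    # each name resolves to a (float, qat, quantized) class triple
    print(name, getattr(Float, name), getattr(QAT, name), getattr(Quantized, name))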