Commit caf1fac2 authored by Megvii Engine Team

refactor(mge/quantization): split `QATModule` and refactor convert api

GitOrigin-RevId: 80cfb12d10590bbc88fd98370f5e3cf5d196d586
Parent: ad3c9315
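In short, after this change a network is quantized in two explicit steps instead of mode switches on a single QATModule: quantize_qat swaps float modules for their QAT counterparts, and quantize swaps those for inference-only quantized modules. A minimal sketch (the toy module below is illustrative, not part of the diff):

    from megengine.module import ConvBn2d
    from megengine.quantization.quantize import quantize, quantize_qat

    net = ConvBn2d(3, 8, kernel_size=3)         # ordinary float Module
    qat_net = quantize_qat(net, inplace=False)  # swap in qat.ConvBn2d, attach qconfig
    # ... QAT training / calibration runs here ...
    q_net = quantize(qat_net)                   # swap in quantized.ConvBn2d, inference only
    q_net.eval()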
@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph

def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
"""Applies a linear transformation to the input.

-Refer to :class:`~.Linear` for more information.
+Refer to :class:`~.module.linear.Linear` for more information.

:param inp: the input tensor with shape `(N, in_features)`.
:param weight: the weight with shape `(out_features, in_features)`.
:param bias: the bias with shape `(out_features,)`.
    Default: ``None``
"""
@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:

def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
r"""
Performs the elementwise function:

.. math::
    \mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta.

For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
......
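A quick numeric check of the softplus formula above (plain numpy, not part of the diff):

    import numpy as np

    beta, x = 1.0, 2.0
    out = np.log(1 + np.exp(beta * x)) / beta
    assert np.isclose(out, 2.1269, atol=1e-4)  # log(1 + e^2) / 1 ~= 2.1269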
@@ -16,7 +16,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
-from .module import Module, QATModule
+from .module import Module
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
......
@@ -9,19 +9,14 @@ from typing import Iterable
from .. import functional as F
from ..core.tensor import Tensor
-from .module import QATModule
+from .module import Module


-class Concat(QATModule):
+class Concat(Module):
r"""
-A :class:`~.QATModule` to do functional concat, should replace concat with this module,
-supporting ``qat`` mode and ``quantized`` mode.
+A :class:`~.Module` to do functional concat. Could be replaced with the :class:`~.QATModule`
+version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inps: Iterable[Tensor], axis: int = 0):
return F.concat(inps, axis)

-def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
-return self.apply_fakequant_with_observer(
-self.forward(inps, axis), self.act_fake_quant, self.act_observer
-)
@@ -7,14 +7,13 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

-from ..core import ones, zeros
-from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
+from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
-from .module import QATModule
+from .module import Module


-class _ConvBn2d(QATModule):
+class _ConvBnActivation2d(Module):
def __init__(
self,
in_channels: int,
@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv
# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv
-class ConvBn2d(_ConvBn2d):
+class ConvBn2d(_ConvBnActivation2d):
r"""
-A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
-and ``normal`` mode.
+A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced
+with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+:func:`~.quantize.quantize_qat`.
"""

-def forward_qat(self, inp):
-return self.apply_fakequant_with_observer(
-self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
-)

def forward(self, inp):
return self.bn(self.conv(inp))
-class ConvBnRelu2d(_ConvBn2d):
+class ConvBnRelu2d(_ConvBnActivation2d):
r"""
-A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
-mode and ``normal`` mode.
+A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
+with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+:func:`~.quantize.quantize_qat`.
"""

-def forward_qat(self, inp):
-return self.apply_fakequant_with_observer(
-relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
-)

def forward(self, inp):
return relu(self.bn(self.conv(inp)))
@@ -8,7 +8,7 @@
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
-from .module import QATModule
+from .module import Module


@wrap_io_tensor
@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


-class Elemwise(QATModule):
+class Elemwise(Module):
r"""
-A :class:`~.QATModule` to do elemwise operator, should replace functional operator with this
-module, supporting ``qat`` mode and ``normal`` mode.
+A :class:`~.Module` to do elemwise operator. Could be replaced with the :class:`~.QATModule`
+version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.

:param method: the elemwise method, support the following string.
    It will do the normal elemwise operator for float.
@@ -88,8 +88,3 @@ class Elemwise(QATModule):
def forward(self, *inps):
return _elemwise_func(self.method, *inps)

-def forward_qat(self, *inps):
-return self.apply_fakequant_with_observer(
-self.forward(*inps), self.act_fake_quant, self.act_observer,
-)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -11,10 +10,10 @@ import numpy as np
from .. import functional as F
from ..core import Parameter
from . import init
-from .module import QATModule
+from .module import Module


-class Linear(QATModule):
+class Linear(Module):
r"""Applies a linear transformation to the input. For instance, if input
is x, then output y is:
@@ -60,13 +59,3 @@ class Linear(QATModule):
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)

-def forward_qat(self, x):
-w_qat = self.apply_fakequant_with_observer(
-self.weight, self.weight_fake_quant, self.weight_observer
-)
-return self.apply_fakequant_with_observer(
-self._calc_linear(x, w_qat, self.bias),
-self.act_fake_quant,
-self.act_observer,
-)
@@ -7,7 +7,6 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
-from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
loaded.append(k)
return set(loaded), set(skipped)
class QATModule(Module):
r"""
Base class of quantization related Module. Add extra forward methods
:meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
``qat``(quantization aware training) mode and ``quantized`` mode respectively.
Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
which is irreversible.
If you want to recursively switch mode for all QATModule in network, use
functions in :mod:`~.quantization.quantize`.
"""
class QATMode(Enum):
DISABLED = 1
QAT = 2
CALIBRATION = 3
def __init__(self):
from ..quantization import (
QConfig,
FakeQuantize,
Observer,
) # pylint: disable=all
super().__init__()
self.quantizing = self.QATMode.DISABLED
self.scale = None
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: "QConfig"):
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
self.weight_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.weight_observer.dtype)
)
self.act_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.act_observer.dtype)
)
def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)
def apply_fakequant_with_observer(
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
if fq is not None:
q_dict = obs.get_qparams()
oup = fq(oup, q_dict)
return oup
def set_qat_mode(self, mode: QATMode):
r"""
Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``,
``QAT``,``CALIBRATION``.
"""
if not isinstance(mode, self.QATMode):
raise TypeError("mode must be QATMode Enum type")
self.quantizing = mode
def to_quantized(self):
r"""
Return a new :class:`~.Module` with quantized parameters of ``self``
according to scale and zero_point in ``self.xxx_observer``.
"""
raise NotImplementedError(
"Use megengine.quantization.quantize to register the method."
)
@abstractmethod
def forward_qat(self, *args, **kwargs):
r"""
Forward method for ``qat`` mode.
"""
def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.DISABLED:
return self.forward(*args, **kwargs)
else:
return self.forward_qat(*args, **kwargs)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
from ...core.tensor import Tensor
from .. import concat as Float
from .module import QATModule
class Concat(Float.Concat, QATModule):
r"""
A :class:`~.QATModule` to do functional concat with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_quant_activation(super().forward(inps, axis))
@classmethod
def from_float_module(cls, float_module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .module import QATModule
class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv
# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv
@classmethod
def from_float_module(cls, float_module: Float._ConvBnActivation2d):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qat_module = cls(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
bool(float_module.conv.bias),
float_module.conv.conv_mode.name,
float_module.conv.compute_mode.name,
)
qat_module.conv.weight = float_module.conv.weight
qat_module.conv.bias = float_module.conv.bias
qat_module.bn = float_module.bn
return qat_module
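As a sanity check on the BN-folding algebra used by fold_weight_bias and calc_conv_bn_qat above, a MegEngine-free numpy sketch where per-channel scalars stand in for the conv kernels (toy sizes are made up):

    import numpy as np

    rng = np.random.default_rng(0)
    c, eps = 4, 1e-5
    w, b = rng.standard_normal(c), rng.standard_normal(c)         # conv "weight" / bias
    gamma, beta = rng.standard_normal(c), rng.standard_normal(c)  # BN affine params
    mean, var = rng.standard_normal(c), rng.random(c) + 0.1       # BN statistics

    bn_istd = 1.0 / np.sqrt(var + eps)
    w_fold = w * gamma * bn_istd                  # w_fold = gamma / bn_std * W
    b_fold = beta + gamma * (b - mean) * bn_istd  # b_fold = gamma * (b - bn_mean) / bn_std + beta

    x = rng.standard_normal(c)
    bn_of_conv = gamma * bn_istd * (w * x + b - mean) + beta
    assert np.allclose(bn_of_conv, w_fold * x + b_fold)  # folded conv == conv followed by BN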
class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_bn_qat(inp))
class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import elemwise as Float
from .module import QATModule
class Elemwise(Float.Elemwise, QATModule):
r"""
A :class:`~.QATModule` to do elemwise operator with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""
def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps))
@classmethod
def from_float_module(cls, float_module: Float.Elemwise):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls(float_module.method.name)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import linear as Float
from .module import QATModule
class Linear(Float.Linear, QATModule):
r"""
A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
:param in_features: size of each input sample.
:param out_features: size of each output sample.
:param bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
"""
def forward(self, x):
w_qat = self.apply_quant_weight(self.weight)
return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias))
@classmethod
def from_float_module(cls, float_module: Float.Linear):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qmod = cls(float_module.in_features, float_module.out_features)
qmod.weight = float_module.weight
qmod.bias = float_module.bias
return qmod
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from ...core import Tensor
from ...quantization import FakeQuantize, Observer, QConfig
from ..module import Module
class QATModule(Module):
r"""
Base class of quantization-related Modules, used mainly for QAT and calibration.

Use :meth:`~.QATModule.from_float_module` to generate an instance from a float :class:`~.Module`.
Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.
Can also be converted to :class:`~.QuantizedModule` for deployment using
:func:`~.quantize.quantize` further.
"""
def __init__(self):
super().__init__()
self.scale = None
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
observer and fake_quant for weight and activation.
"""
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
if qconfig.fake_quant is None:
self.weight_fake_quant = None
self.act_fake_quant = None
else:
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
oup = observer(target)
if fake_quant is None:
return oup
else:
q_dict = observer.get_qparams()
return fake_quant(oup, q_dict)
def apply_quant_weight(self, target: Tensor):
r"""
Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.weight_fake_quant, self.weight_observer
)
def apply_quant_activation(self, target: Tensor):
r"""
Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.act_fake_quant, self.act_observer
)
def get_weight_dtype(self):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
return self.weight_observer.get_dtype()
def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
return self.act_observer.get_dtype()
@classmethod
@abstractmethod
def from_float_module(cls, float_module: Module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import quant_dequant as Float
from .module import QATModule
class QuantStub(Float.QuantStub, QATModule):
r"""
A helper :class:`~.QATModule` that simply returns input, but will quantize
input after being converted to :class:`~.QuantizedModule`.
"""
def forward(self, inp):
return self.apply_quant_activation(inp)
@classmethod
def from_float_module(cls, float_module: Float.QuantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
class DequantStub(Float.DequantStub, QATModule):
r"""
A helper :class:`~.QATModule` that simply returns input, but will de-quantize
input after being converted to :class:`~.QuantizedModule`.
"""
def forward(self, inp):
return inp
@classmethod
def from_float_module(cls, float_module: Float.DequantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
@@ -5,30 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .module import QATModule
+from .module import Module


-class QuantStub(QATModule):
+class QuantStub(Module):
r"""
-A helper QATModule doing quantize operation on input.
+A helper :class:`~.Module` simply returning input. Could be replaced with the :class:`~.QATModule`
+version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):
return inp
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
inp, self.act_fake_quant, self.act_observer
)
-class DequantStub(QATModule):
+class DequantStub(Module):
r"""
-A helper QATModule doing de-quantize operation on input.
+A helper :class:`~.Module` simply returning input. Could be replaced with the :class:`~.QATModule`
+version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):
return inp
def forward_qat(self, inp):
return inp
@@ -9,4 +9,5 @@ from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
+from .module import QuantizedModule
from .quant_dequant import DequantStub, QuantStub
@@ -7,17 +7,15 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

-from ... import _internal as mgb
from ... import functional as F
-from ... import module as Float
from ...core.tensor import Tensor
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import concat as QAT
+from .module import QuantizedModule


-class Concat(Module):
+class Concat(QuantizedModule):
r"""
-A :class:`~.Module` to do quantized concat, inference only.
+A :class:`~.QuantizedModule` to do quantized concat, inference only.
"""

def __init__(self, dtype=None):
@@ -25,16 +23,13 @@ class Concat(QuantizedModule):
self.output_dtype = dtype

def forward(self, inps: Iterable[Tensor], axis: int = 0):
-if self.training:
-raise ValueError("quantized module only support inference.")
new_inps = (x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)
@classmethod
-@register_method_to_class(Float.Concat)
-def to_quantized(float_module):
-r"""
-Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-implemented here to avoid circular import.
-"""
-return Concat(float_module.act_observer.get_dtype())
+def from_qat_module(cls, qat_module: QAT.Concat):
+r"""
+Return a :class:`~.QuantizedModule` instance converted from a
+:class:`~.QATModule` instance.
+"""
+return cls(qat_module.get_activation_dtype())
@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb

@@ -13,11 +12,11 @@ import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
-from ...module import Conv2d
-from ...quantization.utils import register_method_to_class
+from ..qat import conv_bn_relu as QAT
+from .module import QuantizedModule


-class _ConvBnActivation2d(Conv2d):
+class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
r"""Applies a 2D convolution over a quantized input tensor, inference only.

The parameters are the same as :class:`~.Conv2d`.

@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
nonlinear_mode=nonlinear_mode,
)
@classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qat_module.conv.in_channels,
qat_module.conv.out_channels,
qat_module.conv.kernel_size,
qat_module.conv.stride,
qat_module.conv.padding,
qat_module.conv.dilation,
qat_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv
class ConvBn2d(_ConvBnActivation2d):
+r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""

def forward(self, inp):
-if self.training:
-raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
+r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""

def forward(self, inp):
-if self.training:
-raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
def to_quantized(quantized_class, float_module):
output_dtype = float_module.act_observer.get_dtype()
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv
# replace :class:`~.module.QATModule`'s ``to_quantized`` method.
# implemented here to avoid circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
@@ -6,11 +6,10 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
-from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import elemwise as QAT
+from .module import QuantizedModule


@wrap_io_tensor
@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


-class Elemwise(Module):
+class Elemwise(QuantizedModule):
-r"""
-quantized module for elemwise operator, inference only.
-
-:param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
-it will do quantized operator with specified output quantized dtype.
-"""
+r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""

_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

@@ -44,11 +38,10 @@ class Elemwise(QuantizedModule):
raise ValueError("quantized module only support inference.")
return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)
@classmethod
-@register_method_to_class(Float.Elemwise)
-def to_quantized(float_module):
-r"""
-Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-implemented here to avoid circular import.
-"""
-return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
+def from_qat_module(cls, qat_module: QAT.Elemwise):
+r"""
+Return a :class:`~.QuantizedModule` instance converted from a
+:class:`~.QATModule` instance.
+"""
+return cls(qat_module.method.name, qat_module.get_activation_dtype())
@@ -10,19 +10,13 @@ import numpy as np
import megengine._internal as mgb

from ... import functional as F
-from ... import module as Float
from ...core import Parameter
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import linear as QAT
+from .module import QuantizedModule


-class Linear(Module):
-r"""Applies a quantized linear transformation to the input. The module
-usually convert from QAT module by to_quantized method.
-
-:param dtype: output data type.
-"""
+class Linear(QuantizedModule):
+r"""Quantized version of :class:`~.qat.linear.Linear`."""

def __init__(
self, dtype: np.dtype = None,
@@ -44,17 +38,16 @@ class Linear(QuantizedModule):
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)
@classmethod
-@register_method_to_class(Float.Linear)
-def to_quantized(float_module):
-r"""
-Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-implemented here to avoid circular import.
-"""
-output_dtype = float_module.act_observer.get_dtype()
-qmod = Linear(dtype=output_dtype)
-weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
-qmod.weight = Parameter(weight.numpy())
-if float_module.bias is not None:
-qmod.bias = Parameter(float_module.bias.numpy())
-return qmod
+def from_qat_module(cls, qat_module: QAT.Linear):
+r"""
+Return a :class:`~.QuantizedModule` instance converted from a
+:class:`~.QATModule` instance.
+"""
+output_dtype = qat_module.get_activation_dtype()
+qmod = cls(dtype=output_dtype)
+weight = qat_module.weight.astype(qat_module.get_weight_dtype())
+qmod.weight = Parameter(weight.numpy())
+if qat_module.bias is not None:
+qmod.bias = Parameter(qat_module.bias.numpy())
+return qmod
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from ..module import Module
from ..qat import QATModule
class QuantizedModule(Module):
r"""
Base class of quantized Modules, which should be converted from QATModule
and do not support training.
"""
def __call__(self, *inputs, **kwargs):
if self.training:
raise ValueError("quantized module only support inference.")
return super().__call__(*inputs, **kwargs)
@classmethod
@abstractmethod
def from_qat_module(cls, qat_module: QATModule):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
@@ -5,15 +5,14 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ... import _internal as mgb
-from ... import module as Float
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import quant_dequant as QAT
+from .module import QuantizedModule


-class QuantStub(Module):
+class QuantStub(QuantizedModule):
r"""
-A helper quantize operation on input and inference only.
+Quantized version of :class:`~.qat.quant_dequant.QuantStub`,
+will convert input to quantized dtype.
"""
def __init__(self, dtype=None):
@@ -21,35 +20,30 @@ class QuantStub(QuantizedModule):
self.output_dtype = dtype

def forward(self, inp):
-if self.training:
-raise ValueError("quantized module only support inference.")
return inp.astype(self.output_dtype)
@classmethod
def from_qat_module(cls, qat_module: QAT.QuantStub):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.get_activation_dtype())
-class DequantStub(Module):
+class DequantStub(QuantizedModule):
r"""
-A helper de-quantize operation and inference only.
+Quantized version of :class:`~.qat.quant_dequant.DequantStub`,
+will restore quantized input to float32 dtype.
"""
def forward(self, inp):
-if self.training:
-raise ValueError("quantized module only support inference.")
return inp.astype("float32")
@classmethod
-@register_method_to_class(Float.QuantStub)
-def to_quantized(float_module):
-r"""
-Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-implemented here to avoid circular import.
-"""
-return QuantStub(float_module.act_observer.get_dtype())
+def from_qat_module(cls, qat_module: QAT.DequantStub):
+r"""
+Return a :class:`~.QuantizedModule` instance converted from a
+:class:`~.QATModule` instance.
+"""
+return cls()
@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return DequantStub()
@@ -13,12 +13,3 @@ from .qconfig import (
ema_fakequant_qconfig,
min_max_fakequant_qconfig,
)
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_calibration,
quantize_qat,
)
@@ -15,16 +15,12 @@ from .observer import (


class QConfig:
-"""
+r"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
-``activation`` and ``weight``.
-And ``fake_quant`` parameter to indicate
-See :meth:`~.QATModule.set_qconfig` for detail usage.
+``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.

:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
-  - how to collect scales and zero_point of wegiht.
+  how to collect scales and zero_point of weight.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. Can be invoked multiple times to get different
......
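For illustration, a hedged sketch of assembling a custom QConfig; the factory names and import paths (MinMaxObserver, FakeQuantize) are assumptions rather than something this diff defines:

    from megengine.quantization.fake_quant import FakeQuantize  # assumed path
    from megengine.quantization.observer import MinMaxObserver  # assumed path
    from megengine.quantization.qconfig import QConfig

    # Each field is a factory: set_qconfig() calls weight_observer() and
    # act_observer() without arguments, and fake_quant(dtype) once per dtype.
    my_qconfig = QConfig(
        weight_observer=MinMaxObserver,
        act_observer=MinMaxObserver,
        fake_quant=FakeQuantize,
    )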
@@ -6,68 +6,125 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
+from typing import Dict, Tuple

-from ..module import Module, QATModule, Sequential, quantized
+from .. import module as Float
+from ..module import Module
+from ..module import qat as QAT
+from ..module import quantized as Quantized
+from ..module.qat import QATModule
+from ..module.quantized import QuantizedModule
from .qconfig import QConfig, ema_fakequant_qconfig
def _get_quantable_module_names():
def is_quantable(key: str):
value = getattr(Quantized, key)
return (
isinstance(value, type)
and issubclass(value, QuantizedModule)
and value != QuantizedModule
)
# source should have all quantable modules' names
quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
return quantable_module_names
def _get_convert_dict() -> Tuple[
Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
quantable_module_names = _get_quantable_module_names()
quantable_modules = [getattr(Float, key) for key in quantable_module_names]
qat_modules = [getattr(QAT, key) for key in quantable_module_names]
quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]
float2qat_dict = dict(zip(quantable_modules, qat_modules))
qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
return float2qat_dict, qat2quantized_dict
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
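Concretely, the two tables built above pair each float module with its QAT twin, and each QAT module with its quantized twin; roughly (exact contents depend on what the packages export):

    # _float2qat_dict     ~= {Float.Linear: QAT.Linear, Float.Concat: QAT.Concat, ...}
    # _qat2quantized_dict ~= {QAT.Linear: Quantized.Linear, QAT.Concat: Quantized.Concat, ...}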
def quantize(module: Module, inplace=True):
r"""
-Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`.
+Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
+through :meth:`~.Module.apply`.

:param module: root module to do convert recursively.
+:param inplace: whether to convert submodules in-place.
"""
if not inplace:
module = deepcopy(module)

-def is_qat_module(obj):
-return isinstance(obj, QATModule)
+qat_modules = tuple(_qat2quantized_dict.keys())
+
+def is_qat(mod: Module):
+return isinstance(mod, qat_modules)
# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
-with_key=True, with_parent=True, predicate=is_qat_module
+with_key=True, with_parent=True, predicate=is_qat
):
+new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
-if isinstance(parent, Sequential):
+if isinstance(parent, Float.Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
-parent[int(key.split(".")[-1])] = submodule.to_quantized()
+parent[int(key.split(".")[-1])] = new_mod
else:
-setattr(parent, key.split(".")[-1], submodule.to_quantized())
+setattr(parent, key.split(".")[-1], new_mod)
return module
-def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+def quantize_qat(
+    module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
+):
r"""
-Recursively convert `module` to `qat` mode through :meth:`~.Module.apply`
-and set qconfig relatively.
+Recursively convert a float :class:`~.Module` to :class:`~.QATModule`
+through :meth:`~.Module.apply` and set qconfig accordingly.

:param module: root module to do convert recursively.
-:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
-    default is :any:`~.qconfig.ema_fakequant_qconfig`.
+:param inplace: whether to convert submodules in-place.
+:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+    default is ``ema_fakequant_qconfig``.
"""

-def fn(mod: Module):
-if isinstance(mod, QATModule):
-mod.set_qat_mode(QATModule.QATMode.QAT)
-mod.set_qconfig(qconfig)
-
-module.apply(fn)
+if not inplace:
+module = deepcopy(module)
+
+quantable_modules = tuple(_float2qat_dict.keys())
def is_quantable(mod: Module):
return isinstance(mod, quantable_modules)
# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_quantable
):
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
propagate_qconfig(module, qconfig)
return module
-def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+def propagate_qconfig(module: QATModule, qconfig: QConfig):
r"""
-Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
-and set qconfig relatively.
+Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.

-:param module: root module to do convert recursively.
+:param module: root module to traverse recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
-    default is :any:`~.qconfig.ema_fakequant_qconfig`.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
-mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
mod.set_qconfig(qconfig)

module.apply(fn)
......
@@ -5,8 +5,7 @@ import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
-from megengine.quantization import quantize_qat
-from megengine.quantization.quantize import disable_fake_quant
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose
@@ -14,18 +13,17 @@ def test_convbn2d():
in_channels = 32
out_channels = 64
kernel_size = 3
-module = ConvBn2d(in_channels, out_channels, kernel_size)
-quantize_qat(module)
for groups, bias in product([1, 4], [True, False]):
-inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+module = ConvBn2d(
+    in_channels, out_channels, kernel_size, groups=groups, bias=bias
+)
module.train()
-qat_module = copy.deepcopy(module)
+qat_module = quantize_qat(module, inplace=False)
disable_fake_quant(qat_module)
-normal_outputs = module.forward(inputs)
-qat_outputs = qat_module.forward_qat(inputs)
+inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+normal_outputs = module(inputs)
+qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
-a = module.bn.running_mean.numpy()
-b = qat_module.bn.running_mean.numpy()
assertTensorClose(
module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
)
@@ -33,7 +31,7 @@ def test_convbn2d():
module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
)
module.eval()
-normal_outputs = module.forward(inputs)
+normal_outputs = module(inputs)
qat_module.eval()
-qat_outputs = qat_module.forward_qat(inputs)
+qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names
def test_get_quantable_module_names():
# need to make sure names from Quantized and QAT are the same
def _get_qat_module_names():
def is_qat(key: str):
value = getattr(QAT, key)
return (
isinstance(value, type)
and issubclass(value, QAT.QATModule)
and value != QAT.QATModule
)
# source should have all quantable modules' names
quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
return quantable_module_names
qat_module_names = _get_qat_module_names()
quantized_module_names = _get_quantable_module_names()
assert set(qat_module_names) == set(quantized_module_names)
for key in qat_module_names:
value = getattr(Float, key)
assert (
isinstance(value, type)
and issubclass(value, Float.Module)
and value != Float.Module
)