Commit 8c110c39 authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/quantization): add quantization interface

GitOrigin-RevId: 4fe1233ec3412f1db3efdf22d1eed496ff974dc2
Parent 384e7578
......@@ -74,12 +74,14 @@ from .nn import (
softmax,
warp_perspective,
)
from .quantized import conv_bias_activation
from .sort import argsort, sort, top_k
from .tensor import (
add_axis,
arange,
broadcast_to,
concat,
cond_take,
dimshuffle,
gather,
linspace,
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Tuple, Union
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
@wrap_io_tensor
def conv_bias_activation(
inp: Tensor,
weight: Tensor,
bias: Tensor,
dtype=None,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
nonlinear_mode="IDENTITY",
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
) -> Tensor:
""" convolution bias with activation operation, only for inference.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution
:param stride: Stride of the 2D convolution operation. Default: 1
:param padding: Size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: Dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups,
in_channels // groups, height, width)``.
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'.
:param dtype: output data type (a np.dtype). Default: np.int8
:param scale: scale used for quantization. Default: 0.0
:param zero_point: zero point used for quint8 quantization. Default: 0.0
:type compute_mode: string or
:class:`mgb.opr_param_defs.Convolution.ComputeMode`
:param compute_mode: When set to 'DEFAULT', no special requirements will be
placed on the precision of intermediate results. When set to 'FLOAT32',
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
res = mgb.opr.conv_bias_activation(
inp,
weight,
bias,
compute_mode=compute_mode,
dtype=dtype,
strategy=get_conv_execution_strategy(),
nonlineMode=nonlinear_mode,
sparse=sparse_type,
format="NCHW",
pad_h=ph,
pad_w=pw,
stride_h=sh,
stride_w=sw,
dilate_h=dh,
dilate_w=dw,
mode=conv_mode,
)
return res
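As an illustration, a minimal usage sketch mirroring the quantized test added at the end of this commit; the scales and tensor shapes below are arbitrary example values, not part of the API:

import numpy as np
import megengine._internal as mgb
import megengine.functional as F
from megengine import Parameter, tensor

inp_scale, w_scale, out_scale = 0.01, 0.02, 0.1
inp_dtype = mgb.dtype.qint8(inp_scale)
w_dtype = mgb.dtype.qint8(w_scale)
b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
out_dtype = mgb.dtype.qint8(out_scale)

# Quantize float data into int8/int32 storage before calling the op.
inp = tensor(mgb.dtype.convert_to_qint8(np.random.randn(1, 4, 24, 24), inp_dtype), dtype=inp_dtype)
w = Parameter(mgb.dtype.convert_to_qint8(np.random.randn(8, 4, 3, 3), w_dtype), dtype=w_dtype)
b = Parameter(mgb.dtype.convert_to_qint32(np.zeros((1, 8, 1, 1)), b_dtype), dtype=b_dtype)

# Fused int8 convolution + bias + ReLU, producing a qint8 output.
out = F.conv_bias_activation(
    inp, w, b, dtype=out_dtype, stride=1, padding=1, nonlinear_mode="RELU"
)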
......@@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
return out
@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
r"""
Take elements from the input tensor where the given condition on ``mask`` is satisfied. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; both are 1-dimensional. Higher-dimensional input is flattened first.
:param mask: condition param; must be the same shape with data
:param x: input tensor from which to take elements
:param val: value to be compared to by mode
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
x = tensor(np.array([[1, np.inf], [np.nan, 4]],
dtype=np.float32))
v, index = F.cond_take(mask, x, 1)
print(v, index)
Outputs:
.. testoutput::
Tensor([1. 4.]) Tensor([0 3], dtype=int32)
"""
v, index = mgb.opr.cond_take(
x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
)
return v, index
def shapeof(x: Tensor, axis=None):
r"""
The shape of input tensor.
......
......@@ -8,12 +8,16 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module
from .module import Module, QATModule
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule
class Concat(QATModule):
r"""
A :class:`~.QATModule` that wraps functional concat. Replace ``concat`` with this module
to support ``qat`` mode and ``quantized`` mode.
"""
def forward(self, inps: Iterable[Tensor], axis: int = 0):
return F.concat(inps, axis)
def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_fakequant_with_observer(
self.forward(inps, axis), self.act_fake_quant, self.act_observer
)
......@@ -182,11 +182,11 @@ class Conv2d(_ConvNd):
# Assume format is NCHW
return (1, self.out_channels, 1, 1)
def forward(self, inp):
def calc_conv(self, inp, weight, bias):
return conv2d(
inp,
self.weight,
self.bias,
weight,
bias,
self.stride,
self.padding,
self.dilation,
......@@ -195,6 +195,9 @@ class Conv2d(_ConvNd):
self.compute_mode,
)
def forward(self, inp):
return self.calc_conv(inp, self.weight, self.bias)
class ConvTranspose2d(_ConvNd):
r"""Applies a 2D transposed convolution over an input tensor.
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union
from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule
class _ConvBn2d(QATModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
freeze_bn=False,
):
super().__init__()
self.conv = Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
conv_mode,
compute_mode,
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
self.freeze_bn = freeze_bn
def update_bn_stats(self):
self.freeze_bn = False
return self
def freeze_bn_stats(self):
self.freeze_bn = True
return self
def get_bn_gamma_beta(self):
if self.bn.weight is None:
gamma = ones((self.bn.num_features), dtype="float32")
else:
gamma = self.bn.weight
if self.bn.bias is None:
beta = zeros((self.bn.num_features), dtype="float32")
else:
beta = self.bn.bias
return gamma, beta
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma, beta = self.get_bn_gamma_beta()
b = self.conv.bias
if b is None:
b = zeros(self.conv._infer_bias_shape(), dtype="float32")
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
if self.conv.groups == 1:
w_fold = (
self.conv.weight
* gamma.reshape(-1, 1, 1, 1)
* bn_istd.reshape(-1, 1, 1, 1)
)
else:
w_fold = (
self.conv.weight
* gamma.reshape(self.conv.groups, -1, 1, 1, 1)
* bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
)
b_fold = flatten(beta) + (
flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
)
b_fold = b_fold.reshape(self.conv._infer_bias_shape())
return w_fold, b_fold
def calc_conv_bn_qat(self, inp):
# TODO: use pytorch method as
conv = self.conv(inp)
self.bn(conv)
if self.training:
bn_mean, bn_var = self.get_batch_mean_var(conv)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
return self.conv.calc_conv(inp, w_qat, b_fold)
class ConvBn2d(_ConvBn2d):
r"""
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
and ``normal`` mode.
"""
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
)
def forward(self, inp):
return self.bn(self.conv(inp))
class ConvBnRelu2d(_ConvBn2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
mode and ``normal`` mode.
"""
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
)
def forward(self, inp):
return relu(self.bn(self.conv(inp)))
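A minimal sketch of how the fused module might replace separate Conv2d, BatchNorm2d and ReLU layers in a float network; the channel sizes are illustrative:

import megengine.module as M

class Block(M.Module):
    def __init__(self):
        super().__init__()
        # Fused conv + BN + ReLU: runs the plain float path in ``normal`` mode
        # and folds BN into the conv weights in ``qat`` mode.
        self.conv_bn_relu = M.ConvBnRelu2d(3, 8, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv_bn_relu(x)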
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule
@wrap_io_tensor
def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
if all(isinstance(i, (int, float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise(
*inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
)
return ret.inferred_value[0]
return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)
class Elemwise(QATModule):
r"""
A :class:`~.QATModule` that wraps an elemwise operator. Replace the functional operator with
this module to support ``qat`` mode and ``normal`` mode.
:param method: the elemwise method; the following strings are supported.
For float inputs the normal elemwise operator is applied.
* "ADD": a + b
* "FUSE_ADD_RELU": max(x+y, 0)
* "MUL": x * y
* "MIN": min(x, y)
* "MAX": max(x, y)
* "SUB": x - y
* "TRUE_DIV": x / y
* "FUSE_ADD_SIGMOID": sigmoid(x + y)
* "FUSE_ADD_TANH": tanh(x + y)
* "RELU": x > 0 ? x : 0
* "ABS": x > 0 ? x : -x
* "SIGMOID": sigmoid(x)
* "EXP": exp(x)
* "TANH": tanh(x)
* "FUSE_MUL_ADD3": x * y + z
* "FAST_TANH": fast_tanh(x)
* "NEGATE": -x
* "ACOS": acos(x)
* "ASIN": asin(x)
* "CEIL": ceil(x)
* "COS": cos(x)
* "EXPM1": expm1(x)
* "FLOOR": floor(x)
* "LOG": log(x)
* "LOG1P": log1p(x)
* "SIN": sin(x)
* "ROUND": round(x)
* "ERF": erf(x)
* "ERFINV": erfinv(x)
* "ERFC": erfc(x)
* "ERFCINV": erfcinv(x)
* "ABS_GRAD": abs_grad
* "FLOOR_DIV": floor_div
* "MOD": mod
* "SIGMOID_GRAD": sigmoid_grad
* "SWITCH_GT0": switch_gt0
* "TANH_GRAD": tanh_grad
* "LT": lt
* "LEQ": leq
* "EQ": eq
* "POW": pow
* "LOG_SUM_EXP": log_sum_exp
* "FAST_TANH_GRAD": fast_tanh_grad
* "ATAN2": atan2
* "COND_LEQ_MOV": cond_leq_mov
* "H_SWISH": h_swish
* "FUSE_ADD_H_SWISH": h_swish(x+y)
* "H_SWISH_GRAD": h_swish_grad
"""
_elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode
def __init__(self, method):
super().__init__()
self.method = self._elemwise_mode_type.convert(method)
def forward(self, *inps):
return _elemwise_func(self.method, *inps)
def forward_qat(self, *inps):
return self.apply_fakequant_with_observer(
self.forward(*inps), self.act_fake_quant, self.act_observer,
)
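For example, a residual-style addition would be wrapped with this module instead of a bare ``x + y`` so the result can be observed and fake-quantized in ``qat`` mode; a minimal sketch:

import megengine.module as M

class AddRelu(M.Module):
    def __init__(self):
        super().__init__()
        # The fused add + relu elemwise mode listed above.
        self.add_relu = M.Elemwise("FUSE_ADD_RELU")

    def forward(self, x, y):
        return self.add_relu(x, y)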
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -8,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
import numpy as np
......@@ -442,3 +442,95 @@ class Module(metaclass=ABCMeta):
loaded.append(k)
return set(loaded), set(skipped)
class QATModule(Module):
r"""
Base class of quantization-related Modules. It adds the extra forward methods
:meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
``qat`` (quantization aware training) mode and ``quantized`` mode respectively.
Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and ``DISABLED`` mode,
and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
which is irreversible.
To recursively switch the mode of every QATModule in a network, use the
functions in :mod:`~.quantization.quantize`.
"""
class QATMode(Enum):
DISABLED = 1
QAT = 2
CALIBRATION = 3
def __init__(self):
from ..quantization import (
QConfig,
FakeQuantize,
Observer,
) # pylint: disable=all
super().__init__()
self.quantizing = self.QATMode.DISABLED
self.scale = None
self.inp_observer = None # type: Observer
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.bias_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: "QConfig"):
self.inp_observer = qconfig.inp_observer()
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.bias_fake_quant = qconfig.bias_fake_quant()
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)
def apply_fakequant_with_observer(
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
return fq(oup, obs.scale, obs.zero_point)
def set_qat_mode(self, mode: QATMode):
r"""
Change the ``self.quantizing`` mode. Available values: ``QATMode.DISABLED``,
``QATMode.QAT`` and ``QATMode.CALIBRATION``.
"""
if not isinstance(mode, self.QATMode):
raise TypeError("mode must be QATMode Enum type")
self.quantizing = mode
def to_quantized(self):
r"""
Return a new :class:`~.Module` with quantized parameters of ``self``
according to scale and zero_point in ``self.xxx_observer``.
"""
raise NotImplementedError(
"Use megengine.quantization.quantize to register the method."
)
@abstractmethod
def forward_qat(self, *args, **kwargs):
r"""
Forward method for ``qat`` mode.
"""
def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.QAT:
return self.forward_qat(*args, **kwargs)
elif self.quantizing == self.QATMode.CALIBRATION:
# TODO implement the CALIBRATION
assert False
return None
else:
return self.forward(*args, **kwargs)
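A minimal sketch of the intended mode-switching flow for a single QATModule, assuming its quantized counterpart has registered ``to_quantized`` (as the quantized submodules in this commit do):

from megengine.module import ConvBnRelu2d, QATModule
from megengine.quantization import ema_fakequant_qconfig

m = ConvBnRelu2d(3, 8, kernel_size=3, padding=1)
m.set_qconfig(ema_fakequant_qconfig)       # instantiate observers and fake-quant modules
m.set_qat_mode(QATModule.QATMode.QAT)      # __call__ now dispatches to forward_qat
# ... QAT finetuning happens here ...
qm = m.to_quantized()                      # irreversible switch to the quantized module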
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule
class QuantStub(QATModule):
r"""
A helper :class:`~.QATModule` that performs the quantize operation on its input.
"""
def forward(self, inp):
return inp
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
inp, self.act_fake_quant, self.act_observer
)
class DequantStub(QATModule):
r"""
A helper :class:`~.QATModule` that performs the de-quantize operation on its input.
"""
def forward(self, inp):
return inp
def forward_qat(self, inp):
return inp
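A minimal sketch of where the stubs typically sit in a network, marking the float-to-quantized and quantized-to-float boundaries (the layer in between is illustrative):

import megengine.module as M

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()        # fake-quantizes the float input in qat mode
        self.body = M.ConvBnRelu2d(3, 8, kernel_size=3, padding=1)
        self.dequant = M.DequantStub()    # returns to the float domain

    def forward(self, x):
        return self.dequant(self.body(self.quant(x)))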
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .quant_dequant import DequantStub, QuantStub
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module
class Concat(Module):
r"""
A :class:`~.Module` to do quantized concat, inference only.
"""
def __init__(self):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
def forward(self, inps: Iterable[Tensor], axis: int = 0):
if self.training:
raise ValueError("quantized module only support inference.")
new_inps = (x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)
@register_method_to_class(Float.Concat)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method; implemented here to avoid a circular import.
"""
qmod = Concat()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union
import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class
class _ConvBnActivation2d(Conv2d):
r"""Applies a 2D convolution over an quantized input tensor, inference only.
The parameter is same with :class: `~.Conv2d`
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
True,
conv_mode,
compute_mode,
)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.weight = self.weight.astype(self.output_dtype)
self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
inp_scale = mgb.dtype.get_scale(inp.dtype)
w_scale = mgb.dtype.get_scale(self.weight.dtype)
bias_scale = inp_scale * w_scale
return conv_bias_activation(
inp,
self.weight,
self.bias.astype(mgb.dtype.qint32(bias_scale)),
self.output_dtype,
self.stride,
self.padding,
self.dilation,
self.groups,
conv_mode=self.conv_mode,
compute_mode=self.compute_mode,
nonlinear_mode=nonlinear_mode,
)
class ConvBn2d(_ConvBnActivation2d):
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
class ConvBnRelu2d(_ConvBnActivation2d):
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
def to_quantized(quantized_class, float_module):
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.output_dtype = float_module.act_observer.get_dtype()
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()
return qconv
# Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
# implemented here to avoid a circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module
@wrap_io_tensor
def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
if all(isinstance(i, (int, float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise_multi_type(
*inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs,
)
return ret.inferred_value[0]
return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)
class Elemwise(Module):
r"""
Quantized module for elemwise operators, inference only.
:param method: the elemwise method; for supported strings refer to :class:`~.module.elemwise.Elemwise`.
The operator is performed with the specified output quantized dtype.
"""
_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode
def __init__(self, method):
super().__init__()
self.method = self._elemwise_multi_type_mode.convert("Q" + method)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
def forward(self, *inps):
if self.training:
raise ValueError("quantized module only support inference.")
return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)
@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method; implemented here to avoid a circular import.
"""
qmod = Elemwise(float_module.method.name)
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module
class QuantStub(Module):
r"""
A helper module that quantizes its input, inference only.
"""
def __init__(self):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype(self.output_dtype)
class DequantStub(Module):
r"""
A helper module that de-quantizes its input, inference only.
"""
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype("float32")
@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method; implemented here to avoid a circular import.
"""
qmod = QuantStub()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method; implemented here to avoid a circular import.
"""
qmod = DequantStub()
return qmod
......@@ -68,6 +68,7 @@ class Sequential(Module):
def __setitem__(self, idx, module):
key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module)
def __delitem__(self, idx):
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..module import Module
from .observer import Round
class FakeQuantize(Module):
r"""
A module that performs quantize-dequantize (fake quantization) according to the observer's scale and zero_point.
"""
def __init__(self, dtype: str, enable: bool = True):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
)
)
self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax
self.enabled = enable
def enable(self):
self.enabled = True
def disable(self):
self.enabled = False
def forward(self, inp, scale, zero_point):
if self.enabled:
# Quant
oup = Round()(inp / scale) + zero_point
# clip
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
# DeQuant
oup = (oup - zero_point) * scale
return oup
return inp
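A small numeric illustration of the quant-dequant round trip in ``forward``; the scale and zero_point here are arbitrary hand-picked values rather than observer outputs:

import numpy as np
from megengine import tensor
from megengine.quantization import FakeQuantize

fq = FakeQuantize(dtype="qint8")
x = tensor(np.array([0.049, -0.13, 1.0], dtype=np.float32))
scale = tensor(np.array([0.1], dtype=np.float32))
zero_point = tensor(np.array([0.0], dtype=np.float32))

# round(x / scale) -> [0, -1, 10], clipped to [-128, 127], then mapped back,
# so the output is approximately [0.0, -0.1, 1.0]
y = fq(x, scale, zero_point)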
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module
class Round(Function):
def forward(self, x):
return x.round()
def backward(self, output_grads):
return output_grads
class Observer(Module):
r"""
A base class for Observer Module.
:param dtype: a string indicating which dtype to collect the scale and zero_point for
"""
def __init__(self, dtype="qint8"):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
)
)
self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax
self.zero_point, self.scale = None, None
self.enabled = True
def get_dtype(self):
scale, zero_point = self.get_qparams()
numpy_scale = None if scale is None else scale.numpy()[0]
numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
def enable(self):
self.enabled = True
def disable(self):
self.enabled = False
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def get_qparams(self, **kwargs):
pass
class IdentityObserver(Observer):
r"""
A test Observer that always returns scale 1 and zero_point 0.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.zero_point = zeros((1), dtype="float32")
self.scale = ones((1), dtype="float32")
def forward(self, x):
return x
def get_qparams(self):
return self.scale, self.zero_point
class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs)
self.symmetric = symmetric
if self.symmetric:
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
self.min_val = Buffer(0.0, dtype=np.float32)
self.max_val = Buffer(0.0, dtype=np.float32)
self.scale_limit = eps
# flag is used by cond_take: the first run uses first_flag, afterwards it is set to not_flag
self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))
def set_min_max(self, tmp_min, tmp_max):
# FIXME: cond_take will destroy the shape, use reshape to reset it
tmp_min = tmp_min.reshape(1)
tmp_max = tmp_max.reshape(1)
if self.training:
F.zero_grad(
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(
self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
)
)
# FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
# mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further
# calculation in FakeQuant.
self.set_scale_zero_point(tmp_min, tmp_max)
def set_scale_zero_point(self, tmp_min, tmp_max):
if self.symmetric:
symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
# use maximum to avoid the scale being too small at the beginning
self.scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
# zero_point = self.zero_point
else:
# use maximum to avoid the scale being too small at the beginning
self.scale = F.maximum(
(tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
self.zero_point = self.qmin - Round()((tmp_min / self.scale))
def get_qparams(self):
# scale and zero_point are runtime tensors rather than Buffers,
# so they need to be re-calculated if min_val and max_val are loaded.
if self.scale is None:
self.set_scale_zero_point(self.min_val, self.max_val)
return self.scale, self.zero_point
def forward(self, x_orig):
if self.enabled:
# stop gradient
x = F.zero_grad(x_orig)
# find max and min
tmp_min, _ = F.cond_take(
self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
)
tmp_max, _ = F.cond_take(
self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
)
self.set_min_max(tmp_min, tmp_max)
return x_orig
class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(self, momentum=0.9, *args, **kwargs):
super().__init__(*args, **kwargs)
self.momentum = Buffer(momentum)
def set_momentum(self, momentum):
self.momentum.set_value(momentum)
def forward(self, x_orig):
if self.enabled:
# stop gradient
x = F.zero_grad(x_orig)
# Exponential Moving Average
tmp_min, _ = F.cond_take(
self.first_flag,
F.concat(
[
x.min(),
self.momentum * self.min_val + (1 - self.momentum) * x.min(),
]
),
)
tmp_max, _ = F.cond_take(
self.first_flag,
F.concat(
[
x.max(),
self.momentum * self.max_val + (1 - self.momentum) * x.max(),
]
),
)
self.set_min_max(tmp_min, tmp_max)
return x_orig
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from ..module import Module
from .fake_quant import FakeQuantize
from .observer import ExponentialMovingAverageObserver, MinMaxObserver
class QConfig:
"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation``, ``weight`` and ``bias``.
And ``fake_quant`` parameter to indicate
See :meth:`~.QATModule.set_qconfig` for detail usage.
:param inp_observer: interface to instantiate an :class:`~.Observer` indicating
how to collect scales and zero_point of input.
:param weight_observer: similar to ``inp_observer`` but toward weight.
:param act_observer: similar to ``inp_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do the fake_quant calculation. It can be invoked multiple times to get a
different instance for each target tensor, giving finer control over enabling and disabling.
:param bias_fake_quant: similar to ``fake_quant``, but the ``dtype`` usually needs to be set
in advance, since the bias's dtype cannot be inferred from an observer.
Examples:
.. code-block::
# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
"""
def __init__(
self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
):
if (
isinstance(act_observer, Module)
or isinstance(weight_observer, Module)
or isinstance(inp_observer, Module)
):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
" partial(MyObserver, x=1) to override arguments to constructor if needed"
)
self.act_observer = act_observer
self.weight_observer = weight_observer
self.inp_observer = inp_observer
self.fake_quant = fake_quant
self.bias_fake_quant = bias_fake_quant
# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
inp_observer=MinMaxObserver,
weight_observer=MinMaxObserver,
act_observer=MinMaxObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
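As a sketch, a custom config can be assembled the same way, wrapping observer classes with ``partial`` to override constructor arguments; ``net`` below stands for any float model built from QATModules and is not defined in this commit:

from functools import partial

from megengine.quantization import FakeQuantize, QConfig, quantize_qat
from megengine.quantization.observer import MinMaxObserver

# A custom config using asymmetric MinMax observers everywhere.
asym_minmax_qconfig = QConfig(
    inp_observer=partial(MinMaxObserver, symmetric=False),
    weight_observer=partial(MinMaxObserver, symmetric=False),
    act_observer=partial(MinMaxObserver, symmetric=False),
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

quantize_qat(net, qconfig=asym_minmax_qconfig)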
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from ..module import Module, QATModule, Sequential, quantized
from .qconfig import QConfig, ema_fakequant_qconfig
def quantize(module: Module, inplace=True):
r"""
Recursively convert `module` to ``quantized`` mode.
:param module: root module to convert recursively.
:param inplace: whether to convert ``module`` in place. Default: True
"""
if not inplace:
module = deepcopy(module)
def is_qat_module(obj):
return isinstance(obj, QATModule)
# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_qat_module
):
if isinstance(parent, Sequential):
# cannot use setattr here; go through Sequential's ``__setitem__`` for compatibility
parent[int(key.split(".")[-1])] = submodule.to_quantized()
else:
setattr(parent, key.split(".")[-1], submodule.to_quantized())
def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
r"""
Recursively convert `module` to ``qat`` mode through :meth:`~.Module.apply`
and set its qconfig accordingly.
:param module: root module to convert recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as the submodules' qconfig.
Default: :any:`~.qconfig.ema_fakequant_qconfig`.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.QAT)
mod.set_qconfig(qconfig)
module.apply(fn)
def disable_fake_quant(module: Module):
r"""
Recursively disable fake quantization for every QATModule in `module` through :meth:`~.Module.apply`.
:param module: root module on which to recursively disable fake quantization.
"""
def fn(mod):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
mod.bias_fake_quant.disable()
module.apply(fn)
def disable_observer(module: Module):
r"""
Recursively disable the observers of every QATModule in `module` through :meth:`~.Module.apply`.
:param module: root module on which to recursively disable observers.
"""
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.disable()
module.apply(fn)
def enable_fake_quant(module: Module):
r"""
Recursively enable fake quantization for every QATModule in `module` through :meth:`~.Module.apply`.
:param module: root module on which to recursively enable fake quantization.
"""
def fn(mod):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
mod.bias_fake_quant.enable()
module.apply(fn)
def enable_observer(module: Module):
r"""
Recursively enable the observers of every QATModule in `module` through :meth:`~.Module.apply`.
:param module: root module on which to recursively enable observers.
"""
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.enable()
module.apply(fn)
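Putting these helpers together, the overall flow is roughly: switch the float model to qat mode, finetune with fake quantization, then convert in place to the quantized modules. A sketch, where ``net`` is assumed to be a model composed of QATModules:

from megengine.quantization import quantize, quantize_qat
from megengine.quantization.quantize import disable_observer

quantize_qat(net)        # every QATModule enters qat mode with the default EMA qconfig
# ... finetune with fake quantization enabled ...
disable_observer(net)    # freeze the collected scales / zero_points
quantize(net)            # in-place conversion to the quantized counterparts, inference only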
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial, update_wrapper, wraps
def register_method_to_class(cls):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
if isinstance(func, partial):
update_wrapper(func, func.func)
setattr(cls, func.__name__, wrapper)
return func
return decorator
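A minimal sketch of the decorator's behaviour with a hypothetical class and function (neither exists in this commit); the registered function becomes a bound method on the class:

from megengine.quantization.utils import register_method_to_class

class Foo:
    pass

@register_method_to_class(Foo)
def greet(self):
    # attached as ``Foo.greet``; ``self`` is the Foo instance at call time
    return "hello from {}".format(type(self).__name__)

assert Foo().greet() == "hello from Foo"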
......@@ -7,10 +7,12 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
from helpers import opr_test
import megengine._internal as mgb
import megengine.functional as F
from megengine import Buffer, jit, tensor
from megengine import Buffer, Parameter, is_cuda_available, jit, tensor
from megengine.test import assertTensorClose
......@@ -332,3 +334,108 @@ def test_binary_cross_entropy():
{"input": [data2, label2], "output": expect2,},
]
opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)
@pytest.mark.skip
def test_conv_bias():
inp_scale = 0.01
w_scale = 0.02
outp_scale = 0.1
inp_dtype = mgb.dtype.qint8(inp_scale)
w_dtype = mgb.dtype.qint8(w_scale)
b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
out_dtype = mgb.dtype.qint8(outp_scale)
def run(
N,
IC,
OC,
IH,
IW,
KH,
KW,
PH,
PW,
SH,
SW,
has_bias=True,
nonlinear_mode="IDENTITY",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
b_v = np.random.normal(size=(1, OC, 1, 1))
inp_scale = mgb.dtype.get_scale(inp_dtype)
w_scale = mgb.dtype.get_scale(w_dtype)
b_scale = mgb.dtype.get_scale(b_dtype)
inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype)
bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype)
inp_int8 = tensor(inpv, dtype=inp_dtype)
w_int8 = Parameter(wv, dtype=w_dtype)
b_int32 = Parameter(bv, dtype=b_dtype)
inp_fp32 = inp_int8.astype("float32")
w_fp32 = w_int8.astype("float32")
b_fp32 = b_int32.astype("float32")
jit.trace.enabled = True
b_symbolic = True
def convert_to_nchw4(var):
return var.reshape(
var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3)
).dimshuffle(0, 1, 3, 4, 2)
@jit.trace(symbolic=b_symbolic)
def run_conv2d(inp, w, b):
O = F.conv2d(
inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
)
if nonlinear_mode == "RELU":
return F.relu(O)
else:
return O
@jit.trace(symbolic=b_symbolic)
def run_conv_bias(inp, w, b, format="NCHW"):
b = b if has_bias else np.zeros_like(b)
if format == "NCHW4":
inp = convert_to_nchw4(inp)
w = convert_to_nchw4(w)
b = F.flatten(b)
return F.conv_bias_activation(
inp,
w,
b,
stride=(SH, SW),
padding=(PH, PW),
dtype=out_dtype,
nonlinear_mode=nonlinear_mode,
)
format = "NCHW4" if is_cuda_available() else "NCHW"
expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
expected = expected.astype(out_dtype).astype("float32")
result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
"float32"
)
if format == "NCHW4":
result = result.dimshuffle(0, 1, 4, 2, 3)
expected = F.flatten(expected)
result = F.flatten(result)
assertTensorClose(result.numpy(), expected.numpy())
if not is_cuda_available():
run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")