diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py
index f84181e6e564dd318a901a5fbf83ed3b281df4f4..9f7abc0f4eb4166039e50b32c30e72204b2f28d0 100644
--- a/python_module/megengine/functional/__init__.py
+++ b/python_module/megengine/functional/__init__.py
@@ -74,12 +74,14 @@ from .nn import (
     softmax,
     warp_perspective,
 )
+from .quantized import conv_bias_activation
 from .sort import argsort, sort, top_k
 from .tensor import (
     add_axis,
     arange,
     broadcast_to,
     concat,
+    cond_take,
     dimshuffle,
     gather,
     linspace,
diff --git a/python_module/megengine/functional/quantized.py b/python_module/megengine/functional/quantized.py
new file mode 100644
index 0000000000000000000000000000000000000000..da333cc0114cd008bb2d9d07c8caa10835389508
--- /dev/null
+++ b/python_module/megengine/functional/quantized.py
@@ -0,0 +1,95 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# pylint: disable=too-many-lines
+from typing import Tuple, Union
+
+from .. import _internal as mgb
+from ..core import Tensor, wrap_io_tensor
+from ..utils.types import _pair, _pair_nonzero
+from .debug_param import get_conv_execution_strategy
+
+
+@wrap_io_tensor
+def conv_bias_activation(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    dtype=None,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    nonlinear_mode="IDENTITY",
+    conv_mode="CROSS_CORRELATION",
+    compute_mode="DEFAULT",
+) -> Tensor:
+    """Convolution with bias followed by an activation, for inference only.
+
+    :param inp: the feature map of the convolution operation
+    :param weight: the convolution kernel
+    :param bias: the bias added to the result of convolution
+    :param dtype: the output dtype, usually a quantized dtype such as one
+        created by ``mgb.dtype.qint8``. Default: None
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, out_channel // groups,
+        in_channels // groups, height, width)``.
+    :param nonlinear_mode: the activation applied to the result, supports
+        'IDENTITY' and 'RELU'. Default: 'IDENTITY'
+    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
+    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
+        'CROSS_CORRELATION'
+    :type compute_mode: string or
+        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
+    :param compute_mode: when set to 'DEFAULT', no special requirements will be
+        placed on the precision of intermediate results. When set to 'FLOAT32',
+        Float32 would be used for accumulator and intermediate result, but only
+        effective when input and output are of Float16 dtype.
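+
+    A minimal sketch of the expected call pattern (the tensor names and scales
+    below are illustrative assumptions, mirroring the dtypes used in the tests):
+
+    .. code-block::
+
+        inp = inp_int8    # tensor with dtype mgb.dtype.qint8(inp_scale)
+        w = weight_int8   # Parameter with dtype mgb.dtype.qint8(w_scale)
+        b = bias_int32    # Parameter with dtype mgb.dtype.qint32(inp_scale * w_scale)
+        out = conv_bias_activation(
+            inp, w, b, dtype=mgb.dtype.qint8(outp_scale),
+            stride=1, padding=0, nonlinear_mode="RELU",
+        )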
+ + """ + ph, pw = _pair(padding) + sh, sw = _pair_nonzero(stride) + dh, dw = _pair_nonzero(dilation) + sparse_type = "DENSE" if groups == 1 else "GROUP" + res = mgb.opr.conv_bias_activation( + inp, + weight, + bias, + compute_mode=compute_mode, + dtype=dtype, + strategy=get_conv_execution_strategy(), + nonlineMode=nonlinear_mode, + sparse=sparse_type, + format="NCHW", + pad_h=ph, + pad_w=pw, + stride_h=sh, + stride_w=sw, + dilate_h=dh, + dilate_w=dw, + mode=conv_mode, + ) + return res diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index b737cf2f44fcd77cae8393a568efd1d5c9959666..3f1c032da7b0803c038e2b11efed5a2014ea4a01 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: return out +@wrap_io_tensor +def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor: + r""" + Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened. + + :param mask: condition param; must be the same shape with data + :param x: input tensor from which to take elements + :param val: value to be compared to by mode + + Examples: + + .. testcode:: + + from megengine import tensor + import megengine.functional as F + mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) + x = tensor(np.array([[1, np.inf], [np.nan, 4]], + dtype=np.float32)) + v, index = F.cond_take(mask, x, 1) + print(v, index) + + Outputs: + + .. testoutput:: + + Tensor([1. 4.]) Tensor([0 3], dtype=int32) + + """ + + v, index = mgb.opr.cond_take( + x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val + ) + return v, index + + def shapeof(x: Tensor, axis=None): r""" The shape of input tensor. diff --git a/python_module/megengine/module/__init__.py b/python_module/megengine/module/__init__.py index ab1c14a0be2d358df9794af86b4219a6671310aa..f941d2fa1a35d60c61ad8e775905a75f796e124c 100644 --- a/python_module/megengine/module/__init__.py +++ b/python_module/megengine/module/__init__.py @@ -8,12 +8,16 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax from .batchnorm import BatchNorm1d, BatchNorm2d +from .concat import Concat from .conv import Conv2d, ConvTranspose2d +from .conv_bn_relu import ConvBn2d, ConvBnRelu2d from .dropout import Dropout +from .elemwise import Elemwise from .embedding import Embedding from .identity import Identity from .linear import Linear -from .module import Module +from .module import Module, QATModule from .parampack import ParamPack from .pooling import AvgPool2d, MaxPool2d +from .quant_dequant import DequantStub, QuantStub from .sequential import Sequential diff --git a/python_module/megengine/module/concat.py b/python_module/megengine/module/concat.py new file mode 100644 index 0000000000000000000000000000000000000000..b62086d0d03485f75a5bf58761375aad153f3c76 --- /dev/null +++ b/python_module/megengine/module/concat.py @@ -0,0 +1,27 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Iterable
+
+from .. import functional as F
+from ..core.tensor import Tensor
+from .module import QATModule
+
+
+class Concat(QATModule):
+    r"""
+    A :class:`~.QATModule` that performs a functional concat; replace ``concat`` calls
+    with this module to support ``qat`` mode and ``quantized`` mode.
+    """
+
+    def forward(self, inps: Iterable[Tensor], axis: int = 0):
+        return F.concat(inps, axis)
+
+    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
+        return self.apply_fakequant_with_observer(
+            self.forward(inps, axis), self.act_fake_quant, self.act_observer
+        )
diff --git a/python_module/megengine/module/conv.py b/python_module/megengine/module/conv.py
index fbeb50db11cbe9b9b642b6fb55fde6176fb26dbd..26587ad280e8c23544e0af2a1b30775d8f0918f7 100644
--- a/python_module/megengine/module/conv.py
+++ b/python_module/megengine/module/conv.py
@@ -182,11 +182,11 @@ class Conv2d(_ConvNd):
         # Assume format is NCHW
         return (1, self.out_channels, 1, 1)
 
-    def forward(self, inp):
+    def calc_conv(self, inp, weight, bias):
         return conv2d(
             inp,
-            self.weight,
-            self.bias,
+            weight,
+            bias,
             self.stride,
             self.padding,
             self.dilation,
@@ -195,6 +195,9 @@ class Conv2d(_ConvNd):
             self.compute_mode,
         )
 
+    def forward(self, inp):
+        return self.calc_conv(inp, self.weight, self.bias)
+
 
 class ConvTranspose2d(_ConvNd):
     r"""Applies a 2D transposed convolution over an input tensor.
diff --git a/python_module/megengine/module/conv_bn_relu.py b/python_module/megengine/module/conv_bn_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..15964fcd0fb2201098ab7b8398a77d9ed91ec977
--- /dev/null
+++ b/python_module/megengine/module/conv_bn_relu.py
@@ -0,0 +1,175 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
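+# Note: ``_ConvBn2d.fold_weight_bias`` below relies on the standard BN-folding
+# identity (a sketch for reference, with istd = 1 / sqrt(var + eps)):
+#
+#   bn(conv(x)) = gamma * (W @ x + b - mean) * istd + beta
+#               = (gamma * istd) * W @ x + (beta + gamma * istd * (b - mean))
+#
+# so w_fold = gamma * istd * W and b_fold = beta + gamma * istd * (b - mean).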
+from typing import Tuple, Union
+
+from ..core import ones, zeros
+from ..functional import flatten, relu, sqrt, sum
+from .batchnorm import BatchNorm2d
+from .conv import Conv2d
+from .module import QATModule
+
+
+class _ConvBn2d(QATModule):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        conv_mode: str = "CROSS_CORRELATION",
+        compute_mode: str = "DEFAULT",
+        eps=1e-5,
+        momentum=0.9,
+        affine=True,
+        track_running_stats=True,
+        freeze_bn=False,
+    ):
+        super().__init__()
+        self.conv = Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            conv_mode,
+            compute_mode,
+        )
+        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
+        self.freeze_bn = freeze_bn
+
+    def update_bn_stats(self):
+        self.freeze_bn = False
+        return self
+
+    def freeze_bn_stats(self):
+        self.freeze_bn = True
+        return self
+
+    def get_bn_gamma_beta(self):
+        if self.bn.weight is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        else:
+            gamma = self.bn.weight
+
+        if self.bn.bias is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        else:
+            beta = self.bn.bias
+
+        return gamma, beta
+
+    def get_batch_mean_var(self, inp):
+        def _sum_channel(inp, axis=0, keepdims=True):
+            if isinstance(axis, int):
+                out = sum(inp, axis=axis, keepdims=keepdims)
+            elif isinstance(axis, tuple):
+                for idx, elem in enumerate(axis):
+                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
+            return out
+
+        sum1 = _sum_channel(inp, (0, 2, 3))
+        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
+        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
+        batch_mean = sum1 / reduce_size
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
+
+        return batch_mean, batch_var
+
+    def fold_weight_bias(self, bn_mean, bn_var):
+        # get folded conv parameters from bn:
+        #   bn_istd = 1 / bn_std
+        #   w_fold = gamma / bn_std * W
+        #   b_fold = gamma * (b - bn_mean) / bn_std + beta
+        gamma, beta = self.get_bn_gamma_beta()
+        b = self.conv.bias
+        if b is None:
+            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
+        if bn_mean is None:
+            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
+        if bn_var is None:
+            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        if self.conv.groups == 1:
+            w_fold = (
+                self.conv.weight
+                * gamma.reshape(-1, 1, 1, 1)
+                * bn_istd.reshape(-1, 1, 1, 1)
+            )
+        else:
+            w_fold = (
+                self.conv.weight
+                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
+                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
+            )
+        b_fold = flatten(beta) + (
+            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
+        )
+        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
+
+        return w_fold, b_fold
+
+    def calc_conv_bn_qat(self, inp):
+        # TODO: use the fused approach from PyTorch's QAT ConvBn modules
+        conv = self.conv(inp)
+        self.bn(conv)  # output discarded; updates bn running statistics in training
+
+        if self.training:
+            bn_mean, bn_var = self.get_batch_mean_var(conv)
+        else:
+            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
+
+        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
+        w_qat = self.apply_fakequant_with_observer(
+            w_fold, self.weight_fake_quant, self.weight_observer
+        )
+        return self.conv.calc_conv(inp, w_qat, b_fold)
+
+
+class ConvBn2d(_ConvBn2d):
+    r"""
+    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting
+    ``qat`` mode and ``normal`` mode.
+    """
+
+    def forward_qat(self, inp):
+        return self.apply_fakequant_with_observer(
+            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
+        )
+
+    def forward(self, inp):
+        return self.bn(self.conv(inp))
+
+
+class ConvBnRelu2d(_ConvBn2d):
+    r"""
+    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
+    mode and ``normal`` mode.
+    """
+
+    def forward_qat(self, inp):
+        return self.apply_fakequant_with_observer(
+            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
+        )
+
+    def forward(self, inp):
+        return relu(self.bn(self.conv(inp)))
diff --git a/python_module/megengine/module/elemwise.py b/python_module/megengine/module/elemwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca469c544378c5e5bc7c9c1a6f346eea2ef3f55
--- /dev/null
+++ b/python_module/megengine/module/elemwise.py
@@ -0,0 +1,95 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .. import _internal as mgb
+from ..core import Tensor, wrap_io_tensor
+from ..core.graph import _use_default_if_none
+from .module import QATModule
+
+
+@wrap_io_tensor
+def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
+    if all(isinstance(i, (int, float)) for i in inputs):
+        device, comp_graph = _use_default_if_none(None, None)
+        ret = mgb.opr.elemwise(
+            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
+        )
+        return ret.inferred_value[0]
+    return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)
+
+
+class Elemwise(QATModule):
+    r"""
+    A :class:`~.QATModule` wrapping elemwise operators; replace functional elemwise
+    calls with this module to support ``qat`` mode and ``normal`` mode.
+
+    :param method: the elemwise method, one of the strings listed below.
+        For float inputs it performs the corresponding normal elemwise operator.
+
+        * "ADD": a + b
+        * "FUSE_ADD_RELU": max(x+y, 0)
+        * "MUL": x * y
+        * "MIN": min(x, y)
+        * "MAX": max(x, y)
+        * "SUB": x - y
+        * "TRUE_DIV": x / y
+        * "FUSE_ADD_SIGMOID": sigmoid(x + y)
+        * "FUSE_ADD_TANH": tanh(x + y)
+        * "RELU": x > 0 ? x : 0
+        * "ABS": x > 0 ? x : -x
+        * "SIGMOID": sigmoid(x)
+        * "EXP": exp(x)
+        * "TANH": tanh(x)
+        * "FUSE_MUL_ADD3": x * y + z
+        * "FAST_TANH": fast_tanh(x)
+        * "NEGATE": -x
+        * "ACOS": acos(x)
+        * "ASIN": asin(x)
+        * "CEIL": ceil(x)
+        * "COS": cos(x)
+        * "EXPM1": expm1(x)
+        * "FLOOR": floor(x)
+        * "LOG": log(x)
+        * "LOG1P": log1p(x)
+        * "SIN": sin(x)
+        * "ROUND": round(x)
+        * "ERF": erf(x)
+        * "ERFINV": erfinv(x)
+        * "ERFC": erfc(x)
+        * "ERFCINV": erfcinv(x)
+        * "ABS_GRAD": abs_grad
+        * "FLOOR_DIV": floor_div
+        * "MOD": mod
+        * "SIGMOID_GRAD": sigmoid_grad
+        * "SWITCH_GT0": switch_gt0
+        * "TANH_GRAD": tanh_grad
+        * "LT": lt
+        * "LEQ": leq
+        * "EQ": eq
+        * "POW": pow
+        * "LOG_SUM_EXP": log_sum_exp
+        * "FAST_TANH_GRAD": fast_tanh_grad
+        * "ATAN2": atan2
+        * "COND_LEQ_MOV": cond_leq_mov
+        * "H_SWISH": h_swish
+        * "FUSE_ADD_H_SWISH": h_swish(x+y)
+        * "H_SWISH_GRAD": h_swish_grad
+    """
+
+    _elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode
+
+    def __init__(self, method):
+        super().__init__()
+        self.method = self._elemwise_mode_type.convert(method)
+
+    def forward(self, *inps):
+        return _elemwise_func(self.method, *inps)
+
+    def forward_qat(self, *inps):
+        return self.apply_fakequant_with_observer(
+            self.forward(*inps), self.act_fake_quant, self.act_observer,
+        )
diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py
index 89a8edff3a9c4812dcc7da64bcd5082f71eba767..7041b93fdf605c173d5851c75cfe029cfcabf76d 100644
--- a/python_module/megengine/module/module.py
+++ b/python_module/megengine/module/module.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -8,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
+from enum import Enum
 from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -442,3 +442,107 @@ class Module(metaclass=ABCMeta):
             loaded.append(k)
 
         return set(loaded), set(skipped)
+
+
+class QATModule(Module):
+    r"""
+    Base class of quantization related Module. It adds an extra
+    :meth:`~.QATModule.forward_qat` method, which is used in ``qat``
+    (quantization aware training) mode.
+
+    Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and ``DISABLED``
+    mode, and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
+    which is irreversible.
+
+    If you want to recursively switch the mode for all QATModules in a network, use
+    the functions in :mod:`~.quantization.quantize`.
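+
+    A minimal usage sketch (``MyQATNet`` is a hypothetical model assembled from
+    QATModule subclasses such as :class:`~.ConvBnRelu2d`):
+
+    .. code-block::
+
+        from megengine.quantization import quantize_qat, quantize
+
+        net = MyQATNet()
+        quantize_qat(net)  # recursively switch submodules to ``qat`` mode
+        # ... run quantization aware training ...
+        quantize(net)  # irreversibly convert submodules to ``quantized`` mode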
+ """ + + class QATMode(Enum): + DISABLED = 1 + QAT = 2 + CALIBRATION = 3 + + def __init__(self): + from ..quantization import ( + QConfig, + FakeQuantize, + Observer, + ) # pylint: disable=all + + super().__init__() + + self.quantizing = self.QATMode.DISABLED + self.scale = None + + self.inp_observer = None # type: Observer + self.weight_observer = None # type: Observer + self.act_observer = None # type: Observer + + self.weight_fake_quant = None # type: FakeQuantize + self.bias_fake_quant = None # type: FakeQuantize + self.act_fake_quant = None # type: FakeQuantize + + def set_qconfig(self, qconfig: "QConfig"): + self.inp_observer = qconfig.inp_observer() + self.weight_observer = qconfig.weight_observer() + self.act_observer = qconfig.act_observer() + + self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) + self.bias_fake_quant = qconfig.bias_fake_quant() + self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) + + def apply_observer(self, target: Tensor, obs: "Observer"): + return obs(target) + + def apply_fakequant_with_observer( + self, target: Tensor, fq: "FakeQuantize", obs: "Observer" + ): + oup = self.apply_observer(target, obs) + return fq(oup, obs.scale, obs.zero_point) + + def set_qat_mode(self, mode: QATMode): + r""" + Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``, + ``QAT``,``CALIBRATION``. + """ + if not isinstance(mode, self.QATMode): + raise TypeError("mode must be QATMode Enum type") + self.quantizing = mode + + def to_quantized(self): + r""" + Return a new :class:`~.Module` with quantized parameters of ``self`` + according to scale and zero_point in ``self.xxx_observer``. + """ + raise NotImplementedError( + "Use megengine.quantization.quantize to register the method." + ) + + @abstractmethod + def forward_qat(self, *args, **kwargs): + r""" + Forward method for ``qat`` mode. + """ + + def __call__(self, *args, **kwargs): + if self.quantizing == self.QATMode.QAT: + return self.forward_qat(*args, **kwargs) + elif self.quantizing == self.QATMode.CALIBRATION: + # TODO implement the CALIBRATION + assert False + return None + else: + return self.forward(*args, **kwargs) diff --git a/python_module/megengine/module/quant_dequant.py b/python_module/megengine/module/quant_dequant.py new file mode 100644 index 0000000000000000000000000000000000000000..ed20e3c0ee3c4a9bdf5ad3397e57aac7bb35107f --- /dev/null +++ b/python_module/megengine/module/quant_dequant.py @@ -0,0 +1,34 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .module import QATModule + + +class QuantStub(QATModule): + r""" + A helper QATModule doing quantize operation on input. + """ + + def forward(self, inp): + return inp + + def forward_qat(self, inp): + return self.apply_fakequant_with_observer( + inp, self.act_fake_quant, self.act_observer + ) + + +class DequantStub(QATModule): + r""" + A helper QATModule doing de-quantize operation on input. 
+ """ + + def forward(self, inp): + return inp + + def forward_qat(self, inp): + return inp diff --git a/python_module/megengine/module/quantized/__init__.py b/python_module/megengine/module/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c9635e40e604a664f8daec84cba716b8011020 --- /dev/null +++ b/python_module/megengine/module/quantized/__init__.py @@ -0,0 +1,11 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .concat import Concat +from .conv_bn_relu import ConvBn2d, ConvBnRelu2d +from .elemwise import Elemwise +from .quant_dequant import DequantStub, QuantStub diff --git a/python_module/megengine/module/quantized/concat.py b/python_module/megengine/module/quantized/concat.py new file mode 100644 index 0000000000000000000000000000000000000000..62a7778a8923efb089dec2dd3799f466fff1008c --- /dev/null +++ b/python_module/megengine/module/quantized/concat.py @@ -0,0 +1,45 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from typing import Iterable + +from ... import _internal as mgb +from ... import functional as F +from ... import module as Float +from ...core.tensor import Tensor +from ...quantization.utils import register_method_to_class +from ..module import Module + + +class Concat(Module): + r""" + A :class:`~.Module` to do quantized concat, inference only. + """ + + def __init__(self): + super().__init__() + self.scale = 1.0 + self.zero_point = 0.0 + self.output_dtype = mgb.dtype.qint8(self.scale) + + def forward(self, inps: Iterable[Tensor], axis: int = 0): + if self.training: + raise ValueError("quantized module only support inference.") + new_inps = (x.astype(self.output_dtype) for x in inps) + return F.concat(new_inps, axis) + + +@register_method_to_class(Float.Concat) +def to_quantized(float_module): + r""" + Replace :class:`~.module.QATModule`'s ``to_quantized`` method. + implemented here to avoid circular import. + """ + qmod = Concat() + qmod.output_dtype = float_module.act_observer.get_dtype() + qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() + return qmod diff --git a/python_module/megengine/module/quantized/conv_bn_relu.py b/python_module/megengine/module/quantized/conv_bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc502a729e14478bebbfa51c1215e2663c35a95 --- /dev/null +++ b/python_module/megengine/module/quantized/conv_bn_relu.py @@ -0,0 +1,114 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from functools import partial +from typing import Tuple, Union + +import megengine._internal as mgb + +from ... 
+from ...core import Parameter
+from ...functional import conv_bias_activation
+from ...module import Conv2d
+from ...quantization.utils import register_method_to_class
+
+
+class _ConvBnActivation2d(Conv2d):
+    r"""Applies a 2D convolution over a quantized input tensor, inference only.
+
+    The parameters are the same as in :class:`~.Conv2d`.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        conv_mode: str = "CROSS_CORRELATION",
+        compute_mode: str = "DEFAULT",
+    ):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            True,
+            conv_mode,
+            compute_mode,
+        )
+        self.scale = 1.0
+        self.zero_point = 0.0
+        self.output_dtype = mgb.dtype.qint8(self.scale)
+        self.weight = self.weight.astype(self.output_dtype)
+        self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
+
+    def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
+        inp_scale = mgb.dtype.get_scale(inp.dtype)
+        w_scale = mgb.dtype.get_scale(self.weight.dtype)
+        bias_scale = inp_scale * w_scale
+        return conv_bias_activation(
+            inp,
+            self.weight,
+            self.bias.astype(mgb.dtype.qint32(bias_scale)),
+            self.output_dtype,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            conv_mode=self.conv_mode,
+            compute_mode=self.compute_mode,
+            nonlinear_mode=nonlinear_mode,
+        )
+
+
+class ConvBn2d(_ConvBnActivation2d):
+    def forward(self, inp):
+        if self.training:
+            raise ValueError("quantized module only supports inference.")
+        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
+
+
+class ConvBnRelu2d(_ConvBnActivation2d):
+    def forward(self, inp):
+        if self.training:
+            raise ValueError("quantized module only supports inference.")
+        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
+
+
+def to_quantized(quantized_class, float_module):
+    qconv = quantized_class(
+        float_module.conv.in_channels,
+        float_module.conv.out_channels,
+        float_module.conv.kernel_size,
+        float_module.conv.stride,
+        float_module.conv.padding,
+        float_module.conv.dilation,
+        float_module.conv.groups,
+    )
+    w_fold, b_fold = float_module.fold_weight_bias(
+        float_module.bn.running_mean, float_module.bn.running_var
+    )
+    weight = w_fold.astype(float_module.weight_observer.get_dtype())
+    qconv.output_dtype = float_module.act_observer.get_dtype()
+    qconv.weight = Parameter(weight.numpy())
+    qconv.bias = Parameter(b_fold.numpy())
+    qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()
+
+    return qconv
+
+
+# Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
+# implemented here to avoid circular import.
+register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
+register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
diff --git a/python_module/megengine/module/quantized/elemwise.py b/python_module/megengine/module/quantized/elemwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a03ac9a1811f5d4ea766b4712991f101909aba7
--- /dev/null
+++ b/python_module/megengine/module/quantized/elemwise.py
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ... import _internal as mgb
+from ... import module as Float
+from ...core import Tensor, wrap_io_tensor
+from ...core.graph import _use_default_if_none
+from ...quantization.utils import register_method_to_class
+from ..module import Module
+
+
+@wrap_io_tensor
+def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
+    if all(isinstance(i, (int, float)) for i in inputs):
+        device, comp_graph = _use_default_if_none(None, None)
+        ret = mgb.opr.elemwise_multi_type(
+            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs,
+        )
+        return ret.inferred_value[0]
+    return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)
+
+
+class Elemwise(Module):
+    r"""
+    A quantized :class:`~.Module` for elemwise operators, inference only.
+
+    :param method: the elemwise method; the supported strings are listed in :class:`~.module.elemwise.Elemwise`.
+        It performs the quantized operator with the specified output quantized dtype.
+    """
+
+    _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode
+
+    def __init__(self, method):
+        super().__init__()
+        self.method = self._elemwise_multi_type_mode.convert("Q" + method)
+        self.scale = 1.0
+        self.zero_point = 0.0
+        self.output_dtype = mgb.dtype.qint8(self.scale)
+
+    def forward(self, *inps):
+        if self.training:
+            raise ValueError("quantized module only supports inference.")
+        return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)
+
+
+@register_method_to_class(Float.Elemwise)
+def to_quantized(float_module):
+    r"""
+    Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
+    implemented here to avoid circular import.
+    """
+    qmod = Elemwise(float_module.method.name)
+    qmod.output_dtype = float_module.act_observer.get_dtype()
+    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
+    return qmod
diff --git a/python_module/megengine/module/quantized/quant_dequant.py b/python_module/megengine/module/quantized/quant_dequant.py
new file mode 100644
index 0000000000000000000000000000000000000000..5faf923874130584a1fc51bebe94978530114f5c
--- /dev/null
+++ b/python_module/megengine/module/quantized/quant_dequant.py
@@ -0,0 +1,61 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ... import _internal as mgb
+from ... import module as Float
+from ...quantization.utils import register_method_to_class
+from ..module import Module
+
+
+class QuantStub(Module):
+    r"""
+    A helper :class:`~.Module` that quantizes its input, inference only.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.scale = 1.0
+        self.zero_point = 0.0
+        self.output_dtype = mgb.dtype.qint8(self.scale)
+
+    def forward(self, inp):
+        if self.training:
+            raise ValueError("quantized module only supports inference.")
+        return inp.astype(self.output_dtype)
+
+
+class DequantStub(Module):
+    r"""
+    A helper :class:`~.Module` that de-quantizes its input, inference only.
+ """ + + def forward(self, inp): + if self.training: + raise ValueError("quantized module only support inference.") + return inp.astype("float32") + + +@register_method_to_class(Float.QuantStub) +def to_quantized(float_module): + r""" + Replace :class:`~.module.QATModule`'s ``to_quantized`` method. + implemented here to avoid circular import. + """ + qmod = QuantStub() + qmod.output_dtype = float_module.act_observer.get_dtype() + qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() + return qmod + + +@register_method_to_class(Float.DequantStub) +def to_quantized(float_module): + r""" + Replace :class:`~.module.QATModule`'s ``to_quantized`` method. + implemented here to avoid circular import. + """ + qmod = DequantStub() + return qmod diff --git a/python_module/megengine/module/sequential.py b/python_module/megengine/module/sequential.py index 2e1e52914360ec8b7607ae2f707ce2f6bcf56d7f..03afd48a7e3f0b4012e2fd59e2b6bff4d66b602f 100644 --- a/python_module/megengine/module/sequential.py +++ b/python_module/megengine/module/sequential.py @@ -68,6 +68,7 @@ class Sequential(Module): def __setitem__(self, idx, module): key = self.layer_keys[idx] + self.layer_values[idx] = module return setattr(self, key, module) def __delitem__(self, idx): diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..428db50b6d33fd52f6a35fe7c0c0e9372a990fce --- /dev/null +++ b/python_module/megengine/quantization/__init__.py @@ -0,0 +1,11 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .fake_quant import FakeQuantize +from .observer import Observer +from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig +from .quantize import quantize, quantize_qat diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..21652d68f36000b91dac4494df02a82034086a74 --- /dev/null +++ b/python_module/megengine/quantization/fake_quant.py @@ -0,0 +1,48 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .. import functional as F +from .._internal.dtype import _metadata_dict +from ..module import Module +from .observer import Round + + +class FakeQuantize(Module): + r""" + A module to do quant and dequant according to observer's scale and zero_point. 
+ """ + + def __init__(self, dtype: str, enable: bool = True): + super().__init__() + if not dtype in _metadata_dict.keys(): + raise ValueError( + "unknown dtype: {}, only support {}".format( + dtype, _metadata_dict.keys() + ) + ) + self.dtype = dtype + self.qmin = _metadata_dict[dtype].qmin + self.qmax = _metadata_dict[dtype].qmax + self.enabled = enable + + def enable(self): + self.enabled = True + + def disable(self): + self.enabled = False + + def forward(self, inp, scale, zero_point): + if self.enabled: + # Quant + oup = Round()(inp / scale) + zero_point + # clip + oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) + # DeQuant + oup = (oup - zero_point) * scale + return oup + + return inp diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..64e6addaba88e26192617d1991f3f1ce2e95b262 --- /dev/null +++ b/python_module/megengine/quantization/observer.py @@ -0,0 +1,193 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from abc import abstractmethod + +import numpy as np + +from .. import functional as F +from .._internal.dtype import _metadata_dict, get_quantized_dtype +from ..core import Buffer, Function, ones, tensor, zeros +from ..module import Module + + +class Round(Function): + def forward(self, x): + return x.round() + + def backward(self, output_grads): + return output_grads + + +class Observer(Module): + r""" + A base class for Observer Module. + + :param dtype: a string indicating to collect scale and zero_point of which dtype + """ + + def __init__(self, dtype="qint8"): + super().__init__() + if dtype not in _metadata_dict.keys(): + raise ValueError( + "unknown dtype: {}, only support {}".format( + dtype, _metadata_dict.keys() + ) + ) + self.dtype = dtype + self.qmin = _metadata_dict[dtype].qmin + self.qmax = _metadata_dict[dtype].qmax + self.zero_point, self.scale = None, None + self.enabled = True + + def get_dtype(self): + scale, zero_point = self.get_qparams() + numpy_scale = None if scale is None else scale.numpy()[0] + numpy_zero_point = None if zero_point is None else zero_point.numpy()[0] + return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point) + + def enable(self): + self.enabled = True + + def disable(self): + self.enabled = False + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def get_qparams(self, **kwargs): + pass + + +class IdentityObserver(Observer): + r""" + An test Observer that always return scale:1 and zero_point:0. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.zero_point = ones((1), dtype="float32") + self.scale = zeros((1), dtype="float32") + + def forward(self, x): + return x + + def get_qparams(self): + return self.scale, self.zero_point + + +class MinMaxObserver(Observer): + def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): + super().__init__(*args, **kwargs) + self.symmetric = symmetric + if self.symmetric: + # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1' + self.zero_point = tensor((self.qmin + self.qmax + 1) // 2) + + self.min_val = Buffer(0.0, dtype=np.float32) + self.max_val = Buffer(0.0, dtype=np.float32) + self.scale_limit = eps + # flag is used by cond_take, first time will be first flag, and after will be set as not_flag + self.first_flag = Buffer(np.array([1, 0], dtype=np.int32)) + self.not_flag = Buffer(np.array([0, 1], dtype=np.int32)) + + def set_min_max(self, tmp_min, tmp_max): + # FIXME: cond_take will destory shape, use reshape to reset shape + tmp_min = tmp_min.reshape(1) + tmp_max = tmp_max.reshape(1) + if self.training: + F.zero_grad( + F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) + ) + F.zero_grad( + F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) + ) + F.zero_grad( + F.add_update( + self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0 + ) + ) + + # FIXME: add_update is applied after the whole trace procedure in `symbolic=True` + # mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further + # calculation in FakeQuant. + self.set_scale_zero_point(tmp_min, tmp_max) + + def set_scale_zero_point(self, tmp_min, tmp_max): + if self.symmetric: + symmetric_max_vals = F.maximum(-tmp_min, tmp_max) + # use maximun to avoid scale too small at the begin + self.scale = F.maximum( + symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit + ) + # zero_point = self.zero_point + else: + # use maximun to avoid scale too small at the begin + self.scale = F.maximum( + (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit + ) + # caculate zero_point + self.zero_point = self.qmin - Round()((tmp_min / self.scale)) + + def get_qparams(self): + # scale and zero_point is runtime tensor rather than Buffer, + # so need to re-calc if min_val and max_val are loaded. 
+        if self.scale is None:
+            self.set_scale_zero_point(self.min_val, self.max_val)
+
+        return self.scale, self.zero_point
+
+    def forward(self, x_orig):
+        if self.enabled:
+            # stop gradient
+            x = F.zero_grad(x_orig)
+            # find max and min
+            tmp_min, _ = F.cond_take(
+                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
+            )
+            tmp_max, _ = F.cond_take(
+                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
+            )
+            self.set_min_max(tmp_min, tmp_max)
+        return x_orig
+
+
+class ExponentialMovingAverageObserver(MinMaxObserver):
+    def __init__(self, momentum=0.9, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.momentum = Buffer(momentum)
+
+    def set_momentum(self, momentum):
+        self.momentum.set_value(momentum)
+
+    def forward(self, x_orig):
+        if self.enabled:
+            # stop gradient
+            x = F.zero_grad(x_orig)
+            # exponential moving average
+            tmp_min, _ = F.cond_take(
+                self.first_flag,
+                F.concat(
+                    [
+                        x.min(),
+                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
+                    ]
+                ),
+            )
+            tmp_max, _ = F.cond_take(
+                self.first_flag,
+                F.concat(
+                    [
+                        x.max(),
+                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
+                    ]
+                ),
+            )
+            self.set_min_max(tmp_min, tmp_max)
+        return x_orig
diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c348316e7fb9dd03c7d10e7b5bf87894969224
--- /dev/null
+++ b/python_module/megengine/quantization/qconfig.py
@@ -0,0 +1,83 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from functools import partial
+
+from ..module import Module
+from .fake_quant import FakeQuantize
+from .observer import ExponentialMovingAverageObserver, MinMaxObserver
+
+
+class QConfig:
+    """
+    A config class indicating how to quantize a :class:`~.QATModule`'s
+    ``activation``, ``weight`` and ``bias``, with a ``fake_quant`` parameter
+    indicating how fake quantization is performed.
+
+    See :meth:`~.QATModule.set_qconfig` for detailed usage.
+
+    :param inp_observer: interface to instantiate an :class:`~.Observer` indicating
+        how to collect scales and zero_point of input.
+    :param weight_observer: similar to ``inp_observer`` but toward weight.
+    :param act_observer: similar to ``inp_observer`` but toward activation.
+    :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
+        how to do fake_quant calculation. It can be invoked multiple times to get a
+        separate instance for each target tensor, for finer control over enabling
+        and disabling.
+    :param bias_fake_quant: similar to ``fake_quant``, but its ``dtype`` usually needs
+        to be set in advance, since the bias dtype cannot be inferred from an observer.
+
+    Examples:
+
+    .. code-block::
+
+        # Default EMA QConfig for QAT.
+        ema_fakequant_qconfig = QConfig(
+            inp_observer=ExponentialMovingAverageObserver,
+            weight_observer=MinMaxObserver,
+            act_observer=ExponentialMovingAverageObserver,
+            fake_quant=FakeQuantize,
+            bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
+        )
+    """
+
+    def __init__(
+        self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
+    ):
+        if (
+            isinstance(act_observer, Module)
+            or isinstance(weight_observer, Module)
+            or isinstance(inp_observer, Module)
+        ):
+            raise ValueError(
+                "QConfig must not receive observer instance, please pass observer"
+                " class generator using `partial(Observer, ...)` instead. Use"
+                " partial(MyObserver, x=1) to override arguments to constructor if needed"
+            )
+        self.act_observer = act_observer
+        self.weight_observer = weight_observer
+        self.inp_observer = inp_observer
+        self.fake_quant = fake_quant
+        self.bias_fake_quant = bias_fake_quant
+
+
+# Default QAT QConfigs
+min_max_fakequant_qconfig = QConfig(
+    inp_observer=MinMaxObserver,
+    weight_observer=MinMaxObserver,
+    act_observer=MinMaxObserver,
+    fake_quant=FakeQuantize,
+    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
+)
+
+ema_fakequant_qconfig = QConfig(
+    inp_observer=ExponentialMovingAverageObserver,
+    weight_observer=MinMaxObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=FakeQuantize,
+    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
+)
diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..c89ad6dc6b94d4750005a9c14ebe1cccb6ffafbf
--- /dev/null
+++ b/python_module/megengine/quantization/quantize.py
@@ -0,0 +1,122 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from copy import deepcopy
+
+from ..module import Module, QATModule, Sequential, quantized
+from .qconfig import QConfig, ema_fakequant_qconfig
+
+
+def quantize(module: Module, inplace=True):
+    r"""
+    Recursively convert ``module`` to ``quantized`` mode.
+
+    :param module: root module to convert recursively.
+    :param inplace: whether to convert in place or on a deepcopy. Default: True
+    """
+
+    if not inplace:
+        module = deepcopy(module)
+
+    def is_qat_module(obj):
+        return isinstance(obj, QATModule)
+
+    # no need to pass prefix and get pure key of parent Module.
+    for key, submodule, parent in module._flatten(
+        with_key=True, with_parent=True, predicate=is_qat_module
+    ):
+        if isinstance(parent, Sequential):
+            # cannot use setattr here, to stay compatible with Sequential's ``__setitem__``
+            parent[int(key.split(".")[-1])] = submodule.to_quantized()
+        else:
+            setattr(parent, key.split(".")[-1], submodule.to_quantized())
+
+    return module
+
+
+def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+    r"""
+    Recursively convert ``module`` to ``qat`` mode through :meth:`~.Module.apply`
+    and set the qconfig of each submodule accordingly.
+
+    :param module: root module to convert recursively.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+        Default: :any:`~.qconfig.ema_fakequant_qconfig`.
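+
+    A usage sketch (``net`` is any Module containing QATModule submodules):
+
+    .. code-block::
+
+        quantize_qat(net, ema_fakequant_qconfig)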
+ """ + + def fn(mod: Module): + if isinstance(mod, QATModule): + mod.set_qat_mode(QATModule.QATMode.QAT) + mod.set_qconfig(qconfig) + + module.apply(fn) + + +def disable_fake_quant(module: Module): + r""" + Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` + + :param module: root module to do disable fake quantization recursively. + """ + + def fn(mod): + if isinstance(mod, QATModule): + mod.act_fake_quant.disable() + mod.weight_fake_quant.disable() + mod.inp_fake_quant.disable() + + module.apply(fn) + + +def disable_observer(module: Module): + r""" + Recursively disable `module` observer in QATModule through :meth:`~.Module.apply` + + :param module: root module to do disable observer recursively. + """ + + def fn(mod): + if isinstance(mod, QATModule): + mod.act_observer.disable() + + module.apply(fn) + + +def enable_fake_quant(module: Module): + r""" + Recursively enable `module` fake quantization in QATModule through :meth:`~.Module.apply` + + :param module: root module to do enable fake quantization recursively. + """ + + def fn(mod): + if isinstance(mod, QATModule): + mod.act_fake_quant.enable() + mod.weight_fake_quant.enable() + mod.inp_fake_quant.enable() + + module.apply(fn) + + +def enable_observer(module: Module): + r""" + Recursively enable `module` observer in QATModule through :meth:`~.Module.apply` + + :param module: root module to do enable observer recursively. + """ + + def fn(mod): + if isinstance(mod, QATModule): + mod.act_observer.enable() + + module.apply(fn) diff --git a/python_module/megengine/quantization/utils.py b/python_module/megengine/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dff2ddf973c632ea773dd82b1cb15cfceb584d44 --- /dev/null +++ b/python_module/megengine/quantization/utils.py @@ -0,0 +1,23 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from functools import partial, update_wrapper, wraps + + +def register_method_to_class(cls): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + if isinstance(func, partial): + update_wrapper(func, func.func) + setattr(cls, func.__name__, wrapper) + return func + + return decorator diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py index b4e811295fdbe0e8fd047977cf954f13a06f0389..ac0728bafe28b24efab0ad8b59877249dde31626 100644 --- a/python_module/test/unit/functional/test_functional.py +++ b/python_module/test/unit/functional/test_functional.py @@ -7,10 +7,12 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 import numpy as np
+import pytest
 from helpers import opr_test
 
+import megengine._internal as mgb
 import megengine.functional as F
-from megengine import Buffer, jit, tensor
+from megengine import Buffer, Parameter, is_cuda_available, jit, tensor
 from megengine.test import assertTensorClose
 
 
@@ -332,3 +334,108 @@ def test_binary_cross_entropy():
         {"input": [data2, label2], "output": expect2,},
     ]
     opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)
+
+
+@pytest.mark.skip
+def test_conv_bias():
+    inp_scale = 0.01
+    w_scale = 0.02
+    outp_scale = 0.1
+    inp_dtype = mgb.dtype.qint8(inp_scale)
+    w_dtype = mgb.dtype.qint8(w_scale)
+    b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
+    out_dtype = mgb.dtype.qint8(outp_scale)
+
+    def run(
+        N,
+        IC,
+        OC,
+        IH,
+        IW,
+        KH,
+        KW,
+        PH,
+        PW,
+        SH,
+        SW,
+        has_bias=True,
+        nonlinear_mode="IDENTITY",
+    ):
+        inp_v = np.random.normal(size=(N, IC, IH, IW))
+        w_v = np.random.normal(size=(OC, IC, KH, KW))
+        b_v = np.random.normal(size=(1, OC, 1, 1))
+        inp_scale = mgb.dtype.get_scale(inp_dtype)
+        w_scale = mgb.dtype.get_scale(w_dtype)
+        b_scale = mgb.dtype.get_scale(b_dtype)
+
+        inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
+        wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype)
+        bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype)
+
+        inp_int8 = tensor(inpv, dtype=inp_dtype)
+        w_int8 = Parameter(wv, dtype=w_dtype)
+        b_int32 = Parameter(bv, dtype=b_dtype)
+
+        inp_fp32 = inp_int8.astype("float32")
+        w_fp32 = w_int8.astype("float32")
+        b_fp32 = b_int32.astype("float32")
+
+        jit.trace.enabled = True
+        b_symbolic = True
+
+        def convert_to_nchw4(var):
+            return var.reshape(
+                var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3)
+            ).dimshuffle(0, 1, 3, 4, 2)
+
+        @jit.trace(symbolic=b_symbolic)
+        def run_conv2d(inp, w, b):
+            O = F.conv2d(
+                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
+            )
+            if nonlinear_mode == "RELU":
+                return F.relu(O)
+            else:
+                return O
+
+        @jit.trace(symbolic=b_symbolic)
+        def run_conv_bias(inp, w, b, format="NCHW"):
+            b = b if has_bias else np.zeros_like(b)
+            if format == "NCHW4":
+                inp = convert_to_nchw4(inp)
+                w = convert_to_nchw4(w)
+                b = F.flatten(b)
+            return F.conv_bias_activation(
+                inp,
+                w,
+                b,
+                stride=(SH, SW),
+                padding=(PH, PW),
+                dtype=out_dtype,
+                nonlinear_mode=nonlinear_mode,
+            )
+
+        format = "NCHW4" if is_cuda_available() else "NCHW"
+
+        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
+        expected = expected.astype(out_dtype).astype("float32")
+        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
+            "float32"
+        )
+        if format == "NCHW4":
+            result = result.dimshuffle(0, 1, 4, 2, 3)
+        expected = F.flatten(expected)
+        result = F.flatten(result)
+        assertTensorClose(result.numpy(), expected.numpy())
+
+    if not is_cuda_available():
+        run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
+        run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
+        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
+
+    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
+    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
+
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")