From ab9f44f15cf005ae5b81524fc59d08650fe124c4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 19 Nov 2020 15:18:14 +0800 Subject: [PATCH] feat(mge/quantization): add support for easyquant GitOrigin-RevId: 060d908349ca6bdcee293be5a2e47a5bee98af5e --- .../megengine/module/quantized/linear.py | 4 +- .../python/megengine/quantization/__init__.py | 3 +- .../megengine/quantization/fake_quant.py | 32 ++- .../python/megengine/quantization/observer.py | 47 +++- .../python/megengine/quantization/qconfig.py | 40 ++- .../python/megengine/quantization/quantize.py | 121 +++++++- .../python/megengine/quantization/utils.py | 4 +- .../python/test/unit/quantization/quantize.py | 116 -------- .../test/unit/quantization/test_observer.py | 74 ++++- .../test/unit/quantization/test_qconfig.py | 14 + .../test/unit/quantization/test_quantize.py | 266 ++++++++++++++++++ 11 files changed, 561 insertions(+), 160 deletions(-) delete mode 100644 imperative/python/test/unit/quantization/quantize.py create mode 100644 imperative/python/test/unit/quantization/test_qconfig.py create mode 100644 imperative/python/test/unit/quantization/test_quantize.py diff --git a/imperative/python/megengine/module/quantized/linear.py b/imperative/python/megengine/module/quantized/linear.py index c01b2b492..7f8ac43dc 100644 --- a/imperative/python/megengine/module/quantized/linear.py +++ b/imperative/python/megengine/module/quantized/linear.py @@ -17,9 +17,7 @@ from .module import QuantizedModule class Linear(QuantizedModule): r"""Quantized version of :class:`~.qat.linear.Linear`.""" - def __init__( - self, dtype: np.dtype = None, - ): + def __init__(self, dtype: np.dtype = None): super().__init__() self.weight = None self.bias = None diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py index d8be24ee6..427365e5a 100644 --- a/imperative/python/megengine/quantization/__init__.py +++ b/imperative/python/megengine/quantization/__init__.py @@ -15,7 +15,8 @@ from .qconfig import ( ema_fakequant_qconfig, ema_lowbit_fakequant_qconfig, min_max_fakequant_qconfig, + passive_qconfig, sync_ema_fakequant_qconfig, - tqt_quant_qconfig, + tqt_qconfig, ) from .utils import QuantMode diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index 2846131fc..a5accd1dd 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -28,7 +28,9 @@ class _FakeQuantize(Module): :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. """ - def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): + def __init__( + self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs + ): super().__init__() if not dtype in _metadata_dict.keys(): raise ValueError( @@ -114,24 +116,28 @@ class TQT(_FakeQuantize): for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. 
""" - def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): - super().__init__(dtype, narrow_range, enable) - self.scale = Parameter([0.0], dtype=np.float32) + def __init__( + self, + q_dict, + dtype: str, + narrow_range: bool = False, + enable: bool = True, + **kwargs + ): + super().__init__(dtype, narrow_range, enable, **kwargs) + assert ( + q_dict["mode"] == QuantMode.SYMMERTIC + ), "only symmetric quantization is supported by TQT" + if "scale" not in q_dict or q_dict["scale"] is None: + raise AssertionError("Can not get an initialized scale") + self.scale = F.log(q_dict["scale"]) / math.log(2) def fake_quant_forward(self, inp, q_dict=None): # when enable, TQT will do fakequant forward, finetune the scale return TQT_Function(self.qmin, self.qmax)(inp, self.scale) - def normal_foward(self, inp, q_dict=None): - if q_dict["enable_observer"]: - # when disable, TQT will do normal forward, initialize scale weight - tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) - tmp_scale = F.log(tmp_scale / 127) / math.log(2) - self.scale[...] = tmp_scale - return inp - def get_qparams(self): - q_dict = get_qparam_dict(QuantMode.TQT) + q_dict = get_qparam_dict(QuantMode.SYMMERTIC) q_dict["scale"] = 2 ** self.scale return q_dict diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index e3528e15d..4862d79ee 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -7,6 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import math from abc import abstractmethod +from copy import deepcopy import numpy as np @@ -28,7 +29,7 @@ class Observer(Module): instead of 1 greater. Usually True for weight and False for activation. 
""" - def __init__(self, dtype: str, narrow_range: bool = False): + def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): super().__init__() if dtype not in _metadata_dict.keys(): raise ValueError( @@ -81,8 +82,9 @@ class MinMaxObserver(Observer): eps=0.00001, dtype="qint8", narrow_range: bool = False, + **kwargs ): - super().__init__(dtype, narrow_range) + super().__init__(dtype, narrow_range, **kwargs) self.mode = mode self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) @@ -105,7 +107,7 @@ class MinMaxObserver(Observer): else: # use maximun to avoid scale too small at the begin q_dict["scale"] = F.maximum( - (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit, + (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit ) # caculate zero_point q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) @@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver): eps=0.00001, dtype="qint8", narrow_range: bool = False, + **kwargs ): - super().__init__(mode, eps, dtype, narrow_range) + super().__init__(mode, eps, dtype, narrow_range, **kwargs) self.momentum = Tensor(momentum) self.runtime_momentum = Tensor(0.0) @@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver): eps=0.00001, dtype="qint8", narrow_range: bool = False, + **kwargs ): - super().__init__(mode, eps, dtype, narrow_range) + super().__init__(mode, eps, dtype, narrow_range, **kwargs) self.bins = bins self.upsample_rate = upsample_rate self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 @@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver): # combine the existing histogram and new histogram into 1 histogram # We do this by first upsampling the histogram to a dense grid # and then downsampling the histogram efficiently - (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max( + (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max( new_min, new_max, self.upsample_rate ) @@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver): def forward(self, x_orig): self.sideeffect_forward(x_orig) return x_orig + + +class PassiveObserver(Observer): + r""" + This class can be set :attr:`scale` derectly. + """ + + def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs): + super().__init__(dtype, narrow_range, **kwargs) + self.q_dict = deepcopy(q_dict) + if "scale" not in q_dict or q_dict["scale"] is None: + raise AssertionError("Can not get an initialized scale") + self.orig_scale = q_dict["scale"].numpy() + + @property + def scale(self): + return self.q_dict["scale"] + + @scale.setter + def scale(self, value): + assert value > 0 + self.q_dict["scale"].set_value(value) + + def get_qparams(self): + return self.q_dict + + def forward(self, x): + r""" + Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. 
+ """ + return x diff --git a/imperative/python/megengine/quantization/qconfig.py b/imperative/python/megengine/quantization/qconfig.py index 74757c7db..5213a622f 100644 --- a/imperative/python/megengine/quantization/qconfig.py +++ b/imperative/python/megengine/quantization/qconfig.py @@ -13,6 +13,7 @@ from .observer import ( ExponentialMovingAverageObserver, HistogramObserver, MinMaxObserver, + PassiveObserver, SyncExponentialMovingAverageObserver, SyncMinMaxObserver, ) @@ -66,17 +67,22 @@ class QConfig: self.weight_fake_quant = weight_fake_quant self.act_fake_quant = act_fake_quant + def __eq__(self, other): + def eq(a, b): + if isinstance(a, partial) and isinstance(b, partial): + return all( + [a.func == b.func, a.args == b.args, a.keywords == b.keywords] + ) + else: + return a == b + + return ( + eq(self.weight_observer, other.weight_observer) + and eq(self.act_observer, other.act_observer) + and eq(self.weight_fake_quant, other.weight_fake_quant) + and eq(self.act_fake_quant, other.act_fake_quant) + ) -tqt_quant_qconfig = QConfig( - weight_observer=partial( - ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True - ), - act_observer=partial( - ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False - ), - weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), - act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), -) min_max_fakequant_qconfig = QConfig( weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), @@ -118,3 +124,17 @@ calibration_qconfig = QConfig( weight_fake_quant=None, act_fake_quant=None, ) + +tqt_qconfig = QConfig( + weight_observer=None, + act_observer=None, + weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), + act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), +) + +passive_qconfig = QConfig( + weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True), + act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False), + weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), + act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), +) diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 5dab2ae4e..d6216d9b7 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -6,15 +6,18 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import copy, deepcopy +from functools import partial from typing import Callable, Dict, Tuple +import numpy as np + from .. 
import module as Float +from ..functional import concat, norm from ..module import Module from ..module import qat as QAT from ..module import quantized as Quantized from ..module.qat import QATModule from ..module.quantized import QuantizedModule -from .fake_quant import TQT from .qconfig import QConfig, ema_fakequant_qconfig @@ -32,9 +35,7 @@ def _get_quantable_module_names(): return quantable_module_names -def _get_convert_dict() -> Tuple[ - Dict[Module, QATModule], Dict[QATModule, QuantizedModule] -]: +def _get_convert_dict(): quantable_module_names = _get_quantable_module_names() quantable_modules = [getattr(Float, key) for key in quantable_module_names] @@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() +qat_modules = tuple(_qat2quantized_dict.keys()) + + +def is_qat(mod: Module): + return isinstance(mod, qat_modules) def quantize(module: Module, inplace: bool = True, mapping: dict = None): @@ -133,6 +139,34 @@ def quantize_qat( return module +def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): + r""" + Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig`` + + :param module: root module to reset recursively. + :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. + :param inplace: whether to reset submodules in-place. + """ + + if not inplace: + module = deepcopy(module) + + def safe_call(func, q_dict): + return func(q_dict=q_dict) if func is not None else None + + for m in list(module._flatten(predicate=is_qat)): + if m.with_weight: + weight_q_dict = m.get_weight_qparams() + m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict) + m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict) + if m.with_act: + act_q_dict = m.get_activation_qparams() + m.act_observer = safe_call(qconfig.act_observer, act_q_dict) + m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict) + + return module + + def _propagate(module: Module, func_str: str, *args, **kargs): def fn(mod: Module): if isinstance(mod, QATModule): @@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): _propagate(module, "set_qconfig", qconfig) +def hook_qat_module(module: Module, func: Callable): + r""" + Add hooks for all :class:`~.QATModule` submodule + """ + + hooks = [] + for submodule in list(module._flatten(predicate=is_qat)): + hooks.append(submodule.register_forward_hook(func)) + + return hooks + + +def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40): + r""" + Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. + Search for optimal scales. + + :param module: root module. + :param data: input tensor used to search optimal scale. + :param start: lower bound of the search interval. + :param stop: upper bound of the search interval. + :param num: number of samples to search. 
+ """ + + batch_size = data.shape[0] + + def get_cosine(x, y): + ndim = len(x.shape) + axis = tuple(range(1, ndim)) + up = (x * y).sum(axis=axis) + down = norm(x, axis=axis) * norm(y, axis=axis) + sim = up / down + return sim.mean(axis=0) + + def search(mod, inputs, outputs, where): + + mod._forward_hooks.clear() + + fp32_in = [_[:batch_size] for _ in inputs] + int8_in = [_[batch_size:] for _ in inputs] + + disable_fake_quant(mod) + fp32_out = mod(*fp32_in) + enable_fake_quant(mod) + + ob = getattr(mod, where) + if ob is None: + return + + orig_scale = ob.orig_scale + distance = 0 + best_scale = 0 + for scale in np.linspace(start * orig_scale, stop * orig_scale, num): + ob.scale = scale + int8_out = mod(*int8_in) + dis = get_cosine(fp32_out, int8_out) + if dis > distance: + distance = dis + best_scale = scale + ob.scale = best_scale + + if where == "act_observer": + int8_out = mod(*int8_in) + return concat([fp32_out, int8_out]) + else: + int8_out = outputs[batch_size:] + return concat([fp32_out, int8_out]) + + data = concat([data, data]) + + hook_qat_module(module, partial(search, where="weight_observer")) + module(data) + + hook_qat_module(module, partial(search, where="act_observer")) + module(data) + + return module + + def disable_fake_quant(module: Module): r""" Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 970c6beee..31c342e11 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -54,17 +54,15 @@ class QuantMode(Enum): SYMMERTIC = 1 ASYMMERTIC = 2 - TQT = 3 qparam_dict = { - QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, + QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, QuantMode.ASYMMERTIC: { "mode": QuantMode.ASYMMERTIC, "scale": None, "zero_point": None, }, - QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,}, } diff --git a/imperative/python/test/unit/quantization/quantize.py b/imperative/python/test/unit/quantization/quantize.py deleted file mode 100644 index e912e0c0b..000000000 --- a/imperative/python/test/unit/quantization/quantize.py +++ /dev/null @@ -1,116 +0,0 @@ -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-import numpy as np -import pytest - -from megengine import module as Float -from megengine import tensor -from megengine.module import qat as QAT -from megengine.quantization import min_max_fakequant_qconfig -from megengine.quantization.quantize import ( - _get_quantable_module_names, - disable_fake_quant, - quantize_qat, -) - - -def test_get_quantable_module_names(): - # need to make sure names from Quantized and QAT are the same - def _get_qat_module_names(): - def is_qat(key: str): - value = getattr(QAT, key) - return ( - isinstance(value, type) - and issubclass(value, QAT.QATModule) - and value != QAT.QATModule - ) - - # source should have all quantable modules' names - quantable_module_names = [key for key in dir(QAT) if is_qat(key)] - return quantable_module_names - - qat_module_names = _get_qat_module_names() - quantized_module_names = _get_quantable_module_names() - assert set(qat_module_names) == set(quantized_module_names) - - for key in qat_module_names: - value = getattr(Float, key) - assert ( - isinstance(value, type) - and issubclass(value, Float.Module) - and value != Float.Module - ) - - -def test_disable_quantize(): - class Net(Float.Module): - def __init__(self): - super().__init__() - self.conv = Float.ConvBnRelu2d(3, 3, 3) - self.conv.disable_quantize() - - def forward(self, x): - return self.conv(x) - - net = Net() - qat_net = quantize_qat(net, inplace=False) - assert isinstance(qat_net.conv, Float.ConvBnRelu2d) - assert isinstance(qat_net.conv.conv, Float.Conv2d) - - -def test_convert_with_custom_mapping(): - class FloatExample(Float.Module): - def forward(self, x): - return x - - class QATExample(QAT.QATModule): - def forward(self, x): - return x - - @classmethod - def from_float_module(cls, float_module): - return cls() - - class Net(Float.Module): - def __init__(self): - super().__init__() - self.example = FloatExample() - - def forward(self, x): - return self.example(x) - - net = Net() - qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) - assert isinstance(qat_net.example, QATExample) - - -def test_disable_fake_quant(): - class Net(Float.Module): - def __init__(self): - super().__init__() - self.quant = Float.QuantStub() - self.linear = Float.Linear(3, 3) - self.dequant = Float.DequantStub() - self.linear.bias.set_value(np.random.rand(3)) - - def forward(self, x): - x = self.quant(x) - x = self.linear(x) - x = self.dequant(x) - return x - - x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) - net = Net() - y1 = net(x).numpy() - net = quantize_qat(net, min_max_fakequant_qconfig) - y2 = net(x).numpy() - disable_fake_quant(net) - y3 = net(x).numpy() - np.testing.assert_allclose(y1, y3) - with pytest.raises(AssertionError): - np.testing.assert_allclose(y2, y3) diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py index 61aff5eef..f27a8514d 100644 --- a/imperative/python/test/unit/quantization/test_observer.py +++ b/imperative/python/test/unit/quantization/test_observer.py @@ -6,17 +6,53 @@ import pytest import megengine as mge import megengine.distributed as dist -import megengine.quantization.observer as ob from megengine.distributed.helper import get_device_count_by_fork +from megengine.quantization.observer import ( + ExponentialMovingAverageObserver, + MinMaxObserver, + Observer, + PassiveObserver, + SyncExponentialMovingAverageObserver, + SyncMinMaxObserver, +) + + +def test_observer(): + with pytest.raises(TypeError): + Observer("qint8") 
def test_min_max_observer(): x = np.random.rand(3, 3, 3, 3).astype("float32") np_min, np_max = x.min(), x.max() x = mge.tensor(x) - m = ob.MinMaxObserver() + m = MinMaxObserver() m(x) - assert m.min_val == np_min and m.max_val == np_max + np.testing.assert_allclose(m.min_val.numpy(), np_min) + np.testing.assert_allclose(m.max_val.numpy(), np_max) + + +def test_exponential_moving_average_observer(): + t = np.random.rand() + x1 = np.random.rand(3, 3, 3, 3).astype("float32") + x2 = np.random.rand(3, 3, 3, 3).astype("float32") + expected_min = x1.min() * t + x2.min() * (1 - t) + expected_max = x1.max() * t + x2.max() * (1 - t) + m = ExponentialMovingAverageObserver(momentum=t) + m(mge.tensor(x1, dtype=np.float32)) + m(mge.tensor(x2, dtype=np.float32)) + np.testing.assert_allclose(m.min_val.numpy(), expected_min) + np.testing.assert_allclose(m.max_val.numpy(), expected_max) + + +def test_passive_observer(): + q_dict = {"scale": mge.tensor(1.0)} + m = PassiveObserver(q_dict, "qint8") + assert m.orig_scale == 1.0 + assert m.scale == 1.0 + m.scale = 2.0 + assert m.scale == 2.0 + assert m.get_qparams() == {"scale": mge.tensor(2.0)} @pytest.mark.skipif( @@ -35,9 +71,39 @@ def test_sync_min_max_observer(): @dist.launcher def worker(): rank = dist.get_rank() - m = ob.SyncMinMaxObserver() + m = SyncMinMaxObserver() y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) m(y) assert m.min_val == np_min and m.max_val == np_max worker() + + +@pytest.mark.skipif( + platform.system() == "Darwin", reason="do not imp GPU mode at macos now" +) +@pytest.mark.skipif( + platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" +) +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") +@pytest.mark.isolated_distributed +def test_sync_exponential_moving_average_observer(): + word_size = get_device_count_by_fork("gpu") + t = np.random.rand() + x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") + x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") + expected_min = x1.min() * t + x2.min() * (1 - t) + expected_max = x1.max() * t + x2.max() * (1 - t) + + @dist.launcher + def worker(): + rank = dist.get_rank() + m = SyncExponentialMovingAverageObserver(momentum=t) + y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3]) + y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3]) + m(y1) + m(y2) + np.testing.assert_allclose(m.min_val.numpy(), expected_min) + np.testing.assert_allclose(m.max_val.numpy(), expected_max) + + worker() diff --git a/imperative/python/test/unit/quantization/test_qconfig.py b/imperative/python/test/unit/quantization/test_qconfig.py new file mode 100644 index 000000000..92b0150f5 --- /dev/null +++ b/imperative/python/test/unit/quantization/test_qconfig.py @@ -0,0 +1,14 @@ +from functools import partial + +from megengine.quantization import QConfig, tqt_qconfig +from megengine.quantization.fake_quant import TQT + + +def test_equal(): + qconfig = QConfig( + weight_observer=None, + act_observer=None, + weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), + act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), + ) + assert qconfig == tqt_qconfig diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py new file mode 100644 index 000000000..20562d8be --- /dev/null +++ b/imperative/python/test/unit/quantization/test_quantize.py @@ -0,0 +1,266 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np +import pytest + +from megengine import functional +from megengine import module as Float +from megengine import tensor +from megengine.module import qat as QAT +from megengine.module import quantized as Q +from megengine.quantization import ( + min_max_fakequant_qconfig, + passive_qconfig, + tqt_qconfig, +) +from megengine.quantization.fake_quant import TQT, FakeQuantize +from megengine.quantization.observer import MinMaxObserver, PassiveObserver +from megengine.quantization.quantize import ( + _get_quantable_module_names, + apply_easy_quant, + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, + propagate_qconfig, + quantize, + quantize_qat, + reset_qconfig, +) + + +class Net(Float.Module): + def __init__(self): + super().__init__() + self.quant = Float.QuantStub() + self.linear = Float.Linear(3, 3) + self.dequant = Float.DequantStub() + self.linear.bias.set_value(np.random.rand(3)) + + def forward(self, x): + x = self.quant(x) + x = self.linear(x) + x = self.dequant(x) + return x + + +class QATNet(Float.Module): + def __init__(self): + super().__init__() + self.quant = QAT.QuantStub() + self.linear = QAT.Linear(3, 3) + self.dequant = QAT.DequantStub() + self.linear.bias.set_value(np.random.rand(3)) + + def forward(self, x): + x = self.quant(x) + x = self.linear(x) + x = self.dequant(x) + return x + + +def test_propagate_qconfig(): + net = QATNet() + propagate_qconfig(net, min_max_fakequant_qconfig) + assert all( + [ + net.quant.weight_observer is None, + net.quant.weight_fake_quant is None, + isinstance(net.quant.act_observer, MinMaxObserver), + isinstance(net.quant.act_fake_quant, FakeQuantize), + isinstance(net.linear.weight_observer, MinMaxObserver), + isinstance(net.linear.weight_fake_quant, FakeQuantize), + isinstance(net.linear.act_observer, MinMaxObserver), + isinstance(net.linear.act_fake_quant, FakeQuantize), + net.dequant.weight_observer is None, + net.dequant.weight_fake_quant is None, + net.dequant.act_observer is None, + net.dequant.act_observer is None, + ] + ) + + +def init_qat_net(): + net = QATNet() + propagate_qconfig(net, min_max_fakequant_qconfig) + min_val = np.random.randint(-127, 0, size=(2,)) + max_val = np.random.randint(1, 127, size=(2,)) + net.linear.weight_observer.min_val.set_value(min_val[0]) + net.linear.weight_observer.max_val.set_value(max_val[0]) + net.linear.act_observer.min_val.set_value(min_val[1]) + net.linear.act_observer.max_val.set_value(max_val[1]) + return net + + +def test_reset_qconfig(): + qat_net = init_qat_net() + new_qat_net = reset_qconfig(qat_net, passive_qconfig) + assert ( + new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams() + ) + assert ( + new_qat_net.linear.get_activation_qparams() + == qat_net.linear.get_activation_qparams() + ) + + +def test_enable_and_disable_observer(): + net = init_qat_net() + enable_observer(net) + assert net.quant.act_observer.enabled == True + assert net.linear.weight_observer.enabled == True + assert net.linear.act_observer.enabled == True + disable_observer(net) + assert net.quant.act_observer.enabled == False + assert net.linear.weight_observer.enabled == False + assert net.linear.act_observer.enabled == False + + +def test_enable_and_disable_fake_quant(): + net = init_qat_net() + 
disable_fake_quant(net) + assert net.quant.act_fake_quant.enabled == False + assert net.linear.weight_fake_quant.enabled == False + assert net.linear.act_fake_quant.enabled == False + enable_fake_quant(net) + assert net.quant.act_fake_quant.enabled == True + assert net.linear.weight_fake_quant.enabled == True + assert net.linear.act_fake_quant.enabled == True + + +def init_observer(module, data): + enable_observer(module) + disable_fake_quant(module) + module(data) + disable_observer(module) + enable_fake_quant(module) + + +def test_enable_and_disable_all(): + x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) + net = Net() + y1 = net(x).numpy() + net = quantize_qat(net, min_max_fakequant_qconfig) + + init_observer(net, x) + + y2 = net(x).numpy() + disable_fake_quant(net) + y3 = net(x).numpy() + enable_fake_quant(net) + y4 = net(x).numpy() + np.testing.assert_allclose(y1, y3) + np.testing.assert_allclose(y2, y4) + with pytest.raises(AssertionError): + np.testing.assert_allclose(y2, y3) + + +def test_quantize_qat(): + net = Net() + qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) + assert isinstance(qat_net.quant, QAT.QuantStub) + assert isinstance(qat_net.linear, QAT.Linear) + assert isinstance(qat_net.dequant, QAT.DequantStub) + + +def test_quantize(): + qat_net = init_qat_net() + q_net = quantize(qat_net, inplace=False) + assert isinstance(q_net.quant, Q.QuantStub) + assert isinstance(q_net.linear, Q.Linear) + assert isinstance(q_net.dequant, Q.DequantStub) + + +def test_apply_easy_quant(): + qat_net = init_qat_net() + data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32) + eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) + apply_easy_quant(eq_net, data, 0.9, 1.1, 10) + assert isinstance(eq_net.quant.act_observer, PassiveObserver) + assert isinstance(eq_net.linear.weight_observer, PassiveObserver) + assert isinstance(eq_net.linear.act_observer, PassiveObserver) + assert eq_net.dequant.act_observer is None + + +def test_apply_tqt(): + qat_net = init_qat_net() + tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False) + assert isinstance(tqt_net.quant.act_fake_quant, TQT) + assert isinstance(tqt_net.linear.weight_fake_quant, TQT) + assert isinstance(tqt_net.linear.act_fake_quant, TQT) + assert tqt_net.dequant.act_fake_quant is None + + +def test_get_quantable_module_names(): + # need to make sure names from Quantized and QAT are the same + def _get_qat_module_names(): + def is_qat(key: str): + value = getattr(QAT, key) + return ( + isinstance(value, type) + and issubclass(value, QAT.QATModule) + and value != QAT.QATModule + ) + + # source should have all quantable modules' names + quantable_module_names = [key for key in dir(QAT) if is_qat(key)] + return quantable_module_names + + qat_module_names = _get_qat_module_names() + quantized_module_names = _get_quantable_module_names() + assert set(qat_module_names) == set(quantized_module_names) + + for key in qat_module_names: + value = getattr(Float, key) + assert ( + isinstance(value, type) + and issubclass(value, Float.Module) + and value != Float.Module + ) + + +def test_disable_quantize(): + class Net(Float.Module): + def __init__(self): + super().__init__() + self.conv = Float.ConvBnRelu2d(3, 3, 3) + self.conv.disable_quantize() + + def forward(self, x): + return self.conv(x) + + net = Net() + qat_net = quantize_qat(net, inplace=False) + assert isinstance(qat_net.conv, Float.ConvBnRelu2d) + assert isinstance(qat_net.conv.conv, Float.Conv2d) + + +def 
test_convert_with_custom_mapping(): + class FloatExample(Float.Module): + def forward(self, x): + return x + + class QATExample(QAT.QATModule): + def forward(self, x): + return x + + @classmethod + def from_float_module(cls, float_module): + return cls() + + class Net(Float.Module): + def __init__(self): + super().__init__() + self.example = FloatExample() + + def forward(self, x): + return self.example(x) + + net = Net() + qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) + assert isinstance(qat_net.example, QATExample) -- GitLab
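
A minimal end-to-end sketch of the EasyQuant flow this patch enables. The toy Net, the random calibration data, and the chosen search interval below are illustrative only, not part of the change; the API calls (quantize_qat, reset_qconfig, passive_qconfig, apply_easy_quant, quantize) are the ones added or touched by this patch, used the same way as in test_quantize.py.

    import numpy as np

    import megengine as mge
    import megengine.module as Float
    from megengine.quantization import min_max_fakequant_qconfig, passive_qconfig
    from megengine.quantization.quantize import (
        apply_easy_quant,
        quantize,
        quantize_qat,
        reset_qconfig,
    )


    # Illustrative toy network, mirroring the Net used in test_quantize.py.
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.linear(x)
            x = self.dequant(x)
            return x


    calib_data = mge.tensor(np.random.rand(8, 3), dtype=np.float32)

    # 1. Convert to QAT and run data through once so the MinMax observers
    #    hold valid min/max statistics (and therefore an initialized scale).
    qat_net = quantize_qat(Net(), qconfig=min_max_fakequant_qconfig)
    qat_net(calib_data)

    # 2. Swap observers/fake-quants for the passive_qconfig versions; each
    #    PassiveObserver inherits the scale computed in step 1 as orig_scale.
    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)

    # 3. Search every scale in [0.8 * s, 1.2 * s] over 40 candidates,
    #    maximizing cosine similarity between fp32 and fake-quantized outputs.
    apply_easy_quant(eq_net, calib_data, start=0.8, stop=1.2, num=40)

    # 4. Convert to the quantized inference network with the tuned scales.
    q_net = quantize(eq_net)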