Commit ab9f44f1 authored by Megvii Engine Team

feat(mge/quantization): add support for easyquant

GitOrigin-RevId: 060d908349ca6bdcee293be5a2e47a5bee98af5e
Parent fc0fcd2f
...@@ -17,9 +17,7 @@ from .module import QuantizedModule ...@@ -17,9 +17,7 @@ from .module import QuantizedModule
class Linear(QuantizedModule): class Linear(QuantizedModule):
r"""Quantized version of :class:`~.qat.linear.Linear`.""" r"""Quantized version of :class:`~.qat.linear.Linear`."""
def __init__( def __init__(self, dtype: np.dtype = None):
self, dtype: np.dtype = None,
):
super().__init__() super().__init__()
self.weight = None self.weight = None
self.bias = None self.bias = None
......
...@@ -15,7 +15,8 @@ from .qconfig import ( ...@@ -15,7 +15,8 @@ from .qconfig import (
ema_fakequant_qconfig, ema_fakequant_qconfig,
ema_lowbit_fakequant_qconfig, ema_lowbit_fakequant_qconfig,
min_max_fakequant_qconfig, min_max_fakequant_qconfig,
passive_qconfig,
sync_ema_fakequant_qconfig, sync_ema_fakequant_qconfig,
tqt_quant_qconfig, tqt_qconfig,
) )
from .utils import QuantMode from .utils import QuantMode
...@@ -28,7 +28,9 @@ class _FakeQuantize(Module): ...@@ -28,7 +28,9 @@ class _FakeQuantize(Module):
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. :param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
""" """
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
):
super().__init__() super().__init__()
if not dtype in _metadata_dict.keys(): if not dtype in _metadata_dict.keys():
raise ValueError( raise ValueError(
...@@ -114,24 +116,28 @@ class TQT(_FakeQuantize): ...@@ -114,24 +116,28 @@ class TQT(_FakeQuantize):
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
""" """
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): def __init__(
super().__init__(dtype, narrow_range, enable) self,
self.scale = Parameter([0.0], dtype=np.float32) q_dict,
dtype: str,
narrow_range: bool = False,
enable: bool = True,
**kwargs
):
super().__init__(dtype, narrow_range, enable, **kwargs)
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.scale = F.log(q_dict["scale"]) / math.log(2)
def fake_quant_forward(self, inp, q_dict=None): def fake_quant_forward(self, inp, q_dict=None):
# when enabled, TQT does a fakequant forward and finetunes the scale # when enabled, TQT does a fakequant forward and finetunes the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict=None):
if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / math.log(2)
self.scale[...] = tmp_scale
return inp
def get_qparams(self): def get_qparams(self):
q_dict = get_qparam_dict(QuantMode.TQT) q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
q_dict["scale"] = 2 ** self.scale q_dict["scale"] = 2 ** self.scale
return q_dict return q_dict
......
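With this change ``TQT`` takes an already-initialized ``q_dict`` at construction time (symmetric mode only) and derives its log2 scale from it, instead of initializing the scale inside the old ``normal_foward``. A minimal sketch of the new constructor, assuming a hand-built ``q_dict`` (in practice it would come from an observer's ``get_qparams()``) and importing ``QuantMode`` from ``megengine.quantization.utils`` to match the relative import above:

    import megengine as mge
    from megengine.quantization.fake_quant import TQT
    from megengine.quantization.utils import QuantMode

    # hand-built symmetric q_dict; the scale would normally come from an observer
    q_dict = {"mode": QuantMode.SYMMERTIC, "scale": mge.tensor(0.1)}

    tqt = TQT(q_dict, dtype="qint8", narrow_range=False)
    x = mge.tensor([[0.3, -0.7], [1.2, 0.05]])
    y = tqt(x)                           # fake-quant forward, same shape as x
    print(tqt.get_qparams()["scale"])    # reported as 2 ** log2_scale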
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math import math
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy
import numpy as np import numpy as np
...@@ -28,7 +29,7 @@ class Observer(Module): ...@@ -28,7 +29,7 @@ class Observer(Module):
instead of 1 greater. Usually True for weight and False for activation. instead of 1 greater. Usually True for weight and False for activation.
""" """
def __init__(self, dtype: str, narrow_range: bool = False): def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__() super().__init__()
if dtype not in _metadata_dict.keys(): if dtype not in _metadata_dict.keys():
raise ValueError( raise ValueError(
...@@ -81,8 +82,9 @@ class MinMaxObserver(Observer): ...@@ -81,8 +82,9 @@ class MinMaxObserver(Observer):
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
**kwargs
): ):
super().__init__(dtype, narrow_range) super().__init__(dtype, narrow_range, **kwargs)
self.mode = mode self.mode = mode
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
...@@ -105,7 +107,7 @@ class MinMaxObserver(Observer): ...@@ -105,7 +107,7 @@ class MinMaxObserver(Observer):
else: else:
# use maximum to avoid the scale being too small at the beginning # use maximum to avoid the scale being too small at the beginning
q_dict["scale"] = F.maximum( q_dict["scale"] = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit, (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
) )
# calculate zero_point # calculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
...@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
**kwargs
): ):
super().__init__(mode, eps, dtype, narrow_range) super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.momentum = Tensor(momentum) self.momentum = Tensor(momentum)
self.runtime_momentum = Tensor(0.0) self.runtime_momentum = Tensor(0.0)
...@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver): ...@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver):
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
**kwargs
): ):
super().__init__(mode, eps, dtype, narrow_range) super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.bins = bins self.bins = bins
self.upsample_rate = upsample_rate self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
...@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver): ...@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver):
# combine the existing histogram and new histogram into 1 histogram # combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid # We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently # and then downsampling the histogram efficiently
(new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max( (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
new_min, new_max, self.upsample_rate new_min, new_max, self.upsample_rate
) )
...@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver): ...@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver):
def forward(self, x_orig): def forward(self, x_orig):
self.sideeffect_forward(x_orig) self.sideeffect_forward(x_orig)
return x_orig return x_orig
class PassiveObserver(Observer):
r"""
An observer whose :attr:`scale` can be set directly.
"""
def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__(dtype, narrow_range, **kwargs)
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
@property
def scale(self):
return self.q_dict["scale"]
@scale.setter
def scale(self, value):
assert value > 0
self.q_dict["scale"].set_value(value)
def get_qparams(self):
return self.q_dict
def forward(self, x):
r"""
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
"""
return x
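``PassiveObserver`` never collects statistics; it just carries a fixed ``q_dict`` whose ``scale`` can be overwritten, which is exactly the hook ``apply_easy_quant`` relies on. A minimal usage sketch, mirroring the new ``test_passive_observer``:

    import megengine as mge
    from megengine.quantization.observer import PassiveObserver

    ob = PassiveObserver({"scale": mge.tensor(1.0)}, "qint8")
    ob.scale = 2.0                  # overwrites the stored scale in place
    print(ob.get_qparams())         # {"scale": Tensor(2.0)}

    x = mge.tensor([1.0, 2.0, 3.0])
    y = ob(x)                       # forward is a pass-through: y == x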
...@@ -13,6 +13,7 @@ from .observer import ( ...@@ -13,6 +13,7 @@ from .observer import (
ExponentialMovingAverageObserver, ExponentialMovingAverageObserver,
HistogramObserver, HistogramObserver,
MinMaxObserver, MinMaxObserver,
PassiveObserver,
SyncExponentialMovingAverageObserver, SyncExponentialMovingAverageObserver,
SyncMinMaxObserver, SyncMinMaxObserver,
) )
...@@ -66,17 +67,22 @@ class QConfig: ...@@ -66,17 +67,22 @@ class QConfig:
self.weight_fake_quant = weight_fake_quant self.weight_fake_quant = weight_fake_quant
self.act_fake_quant = act_fake_quant self.act_fake_quant = act_fake_quant
def __eq__(self, other):
def eq(a, b):
if isinstance(a, partial) and isinstance(b, partial):
return all(
[a.func == b.func, a.args == b.args, a.keywords == b.keywords]
)
else:
return a == b
return (
eq(self.weight_observer, other.weight_observer)
and eq(self.act_observer, other.act_observer)
and eq(self.weight_fake_quant, other.weight_fake_quant)
and eq(self.act_fake_quant, other.act_fake_quant)
)
tqt_quant_qconfig = QConfig(
weight_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True
),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
min_max_fakequant_qconfig = QConfig( min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
...@@ -118,3 +124,17 @@ calibration_qconfig = QConfig( ...@@ -118,3 +124,17 @@ calibration_qconfig = QConfig(
weight_fake_quant=None, weight_fake_quant=None,
act_fake_quant=None, act_fake_quant=None,
) )
tqt_qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
passive_qconfig = QConfig(
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
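With the new ``__eq__``, a ``QConfig`` rebuilt from equivalent ``partial``s compares equal to the predefined one (this is what the new ``test_equal`` checks). A small sketch:

    from functools import partial

    from megengine.quantization import QConfig, tqt_qconfig
    from megengine.quantization.fake_quant import TQT

    my_tqt_qconfig = QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
        act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
    )
    assert my_tqt_qconfig == tqt_qconfig   # partials compared by func, args and keywords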
...@@ -6,15 +6,18 @@ ...@@ -6,15 +6,18 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import copy, deepcopy from copy import copy, deepcopy
from functools import partial
from typing import Callable, Dict, Tuple from typing import Callable, Dict, Tuple
import numpy as np
from .. import module as Float from .. import module as Float
from ..functional import concat, norm
from ..module import Module from ..module import Module
from ..module import qat as QAT from ..module import qat as QAT
from ..module import quantized as Quantized from ..module import quantized as Quantized
from ..module.qat import QATModule from ..module.qat import QATModule
from ..module.quantized import QuantizedModule from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig from .qconfig import QConfig, ema_fakequant_qconfig
...@@ -32,9 +35,7 @@ def _get_quantable_module_names(): ...@@ -32,9 +35,7 @@ def _get_quantable_module_names():
return quantable_module_names return quantable_module_names
def _get_convert_dict() -> Tuple[ def _get_convert_dict():
Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
quantable_module_names = _get_quantable_module_names() quantable_module_names = _get_quantable_module_names()
quantable_modules = [getattr(Float, key) for key in quantable_module_names] quantable_modules = [getattr(Float, key) for key in quantable_module_names]
...@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[ ...@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
def quantize(module: Module, inplace: bool = True, mapping: dict = None): def quantize(module: Module, inplace: bool = True, mapping: dict = None):
...@@ -133,6 +139,34 @@ def quantize_qat( ...@@ -133,6 +139,34 @@ def quantize_qat(
return module return module
def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
r"""
Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``
:param module: root module to reset recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
:param inplace: whether to reset submodules in-place.
"""
if not inplace:
module = deepcopy(module)
def safe_call(func, q_dict):
return func(q_dict=q_dict) if func is not None else None
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
if m.with_act:
act_q_dict = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
return module
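``reset_qconfig`` swaps each QAT submodule's observers and fake-quant modules for ones built from the given qconfig, seeding them with the module's current qparams. A rough usage sketch, assuming a hypothetical toy ``Net`` that has already seen some calibration data:

    import numpy as np
    import megengine as mge
    from megengine import module as Float
    from megengine.quantization import min_max_fakequant_qconfig, passive_qconfig, tqt_qconfig
    from megengine.quantization.quantize import quantize_qat, reset_qconfig

    class Net(Float.Module):                 # hypothetical toy network
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    qat_net = quantize_qat(Net(), qconfig=min_max_fakequant_qconfig)
    qat_net(mge.tensor(np.random.rand(8, 3), dtype=np.float32))       # let observers see data

    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)   # scales become editable
    tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)      # fake quant swapped to TQT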
def _propagate(module: Module, func_str: str, *args, **kargs): def _propagate(module: Module, func_str: str, *args, **kargs):
def fn(mod: Module): def fn(mod: Module):
if isinstance(mod, QATModule): if isinstance(mod, QATModule):
...@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): ...@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
_propagate(module, "set_qconfig", qconfig) _propagate(module, "set_qconfig", qconfig)
def hook_qat_module(module: Module, func: Callable):
r"""
Add hooks for all :class:`~.QATModule` submodules.
"""
hooks = []
for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func))
return hooks
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
r"""
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
Searches for the optimal observer scales by maximizing the cosine similarity between float and fake-quantized outputs.
:param module: root module.
:param data: input tensor used to search optimal scale.
:param start: lower bound of the search interval.
:param stop: upper bound of the search interval.
:param num: number of samples to search.
"""
batch_size = data.shape[0]
def get_cosine(x, y):
ndim = len(x.shape)
axis = tuple(range(1, ndim))
up = (x * y).sum(axis=axis)
down = norm(x, axis=axis) * norm(y, axis=axis)
sim = up / down
return sim.mean(axis=0)
def search(mod, inputs, outputs, where):
mod._forward_hooks.clear()
fp32_in = [_[:batch_size] for _ in inputs]
int8_in = [_[batch_size:] for _ in inputs]
disable_fake_quant(mod)
fp32_out = mod(*fp32_in)
enable_fake_quant(mod)
ob = getattr(mod, where)
if ob is None:
return
orig_scale = ob.orig_scale
distance = 0
best_scale = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale
int8_out = mod(*int8_in)
dis = get_cosine(fp32_out, int8_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if where == "act_observer":
int8_out = mod(*int8_in)
return concat([fp32_out, int8_out])
else:
int8_out = outputs[batch_size:]
return concat([fp32_out, int8_out])
data = concat([data, data])
hook_qat_module(module, partial(search, where="weight_observer"))
module(data)
hook_qat_module(module, partial(search, where="act_observer"))
module(data)
return module
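Putting the pieces together, the EasyQuant flow exercised by the new tests is: calibrate a QAT network, freeze its qparams into ``PassiveObserver``s via ``passive_qconfig``, then let ``apply_easy_quant`` search around each stored scale. A rough end-to-end sketch with a hypothetical toy ``Net`` (names and data are illustrative only):

    import numpy as np
    import megengine as mge
    from megengine import module as Float
    from megengine.quantization import min_max_fakequant_qconfig, passive_qconfig
    from megengine.quantization.quantize import (
        apply_easy_quant,
        disable_fake_quant,
        disable_observer,
        enable_fake_quant,
        enable_observer,
        quantize,
        quantize_qat,
        reset_qconfig,
    )

    class Net(Float.Module):                 # hypothetical toy network
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    qat_net = quantize_qat(Net(), qconfig=min_max_fakequant_qconfig)

    # calibration pass: observers on, fake quant off
    enable_observer(qat_net)
    disable_fake_quant(qat_net)
    calib = mge.tensor(np.random.rand(16, 3), dtype=np.float32)
    qat_net(calib)
    disable_observer(qat_net)
    enable_fake_quant(qat_net)

    # freeze the calibrated qparams into PassiveObservers, then search around each scale
    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
    apply_easy_quant(eq_net, calib, start=0.9, stop=1.1, num=10)

    q_net = quantize(eq_net, inplace=False)  # convert the tuned net to quantized modules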
def disable_fake_quant(module: Module): def disable_fake_quant(module: Module):
r""" r"""
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
......
...@@ -54,17 +54,15 @@ class QuantMode(Enum): ...@@ -54,17 +54,15 @@ class QuantMode(Enum):
SYMMERTIC = 1 SYMMERTIC = 1
ASYMMERTIC = 2 ASYMMERTIC = 2
TQT = 3
qparam_dict = { qparam_dict = {
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
QuantMode.ASYMMERTIC: { QuantMode.ASYMMERTIC: {
"mode": QuantMode.ASYMMERTIC, "mode": QuantMode.ASYMMERTIC,
"scale": None, "scale": None,
"zero_point": None, "zero_point": None,
}, },
QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
} }
......
...@@ -6,17 +6,53 @@ import pytest ...@@ -6,17 +6,53 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
import megengine.quantization.observer as ob
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization.observer import (
ExponentialMovingAverageObserver,
MinMaxObserver,
Observer,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
def test_observer():
with pytest.raises(TypeError):
Observer("qint8")
def test_min_max_observer(): def test_min_max_observer():
x = np.random.rand(3, 3, 3, 3).astype("float32") x = np.random.rand(3, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max() np_min, np_max = x.min(), x.max()
x = mge.tensor(x) x = mge.tensor(x)
m = ob.MinMaxObserver() m = MinMaxObserver()
m(x) m(x)
assert m.min_val == np_min and m.max_val == np_max np.testing.assert_allclose(m.min_val.numpy(), np_min)
np.testing.assert_allclose(m.max_val.numpy(), np_max)
def test_exponential_moving_average_observer():
t = np.random.rand()
x1 = np.random.rand(3, 3, 3, 3).astype("float32")
x2 = np.random.rand(3, 3, 3, 3).astype("float32")
expected_min = x1.min() * t + x2.min() * (1 - t)
expected_max = x1.max() * t + x2.max() * (1 - t)
m = ExponentialMovingAverageObserver(momentum=t)
m(mge.tensor(x1, dtype=np.float32))
m(mge.tensor(x2, dtype=np.float32))
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8")
assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0
assert m.scale == 2.0
assert m.get_qparams() == {"scale": mge.tensor(2.0)}
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -35,9 +71,39 @@ def test_sync_min_max_observer(): ...@@ -35,9 +71,39 @@ def test_sync_min_max_observer():
@dist.launcher @dist.launcher
def worker(): def worker():
rank = dist.get_rank() rank = dist.get_rank()
m = ob.SyncMinMaxObserver() m = SyncMinMaxObserver()
y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
m(y) m(y)
assert m.min_val == np_min and m.max_val == np_max assert m.min_val == np_min and m.max_val == np_max
worker() worker()
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
word_size = get_device_count_by_fork("gpu")
t = np.random.rand()
x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
expected_min = x1.min() * t + x2.min() * (1 - t)
expected_max = x1.max() * t + x2.max() * (1 - t)
@dist.launcher
def worker():
rank = dist.get_rank()
m = SyncExponentialMovingAverageObserver(momentum=t)
y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3])
y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
m(y1)
m(y2)
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
worker()
from functools import partial
from megengine.quantization import QConfig, tqt_qconfig
from megengine.quantization.fake_quant import TQT
def test_equal():
qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
assert qconfig == tqt_qconfig
...@@ -8,17 +8,194 @@ ...@@ -8,17 +8,194 @@
import numpy as np import numpy as np
import pytest import pytest
from megengine import functional
from megengine import module as Float from megengine import module as Float
from megengine import tensor from megengine import tensor
from megengine.module import qat as QAT from megengine.module import qat as QAT
from megengine.quantization import min_max_fakequant_qconfig from megengine.module import quantized as Q
from megengine.quantization import (
min_max_fakequant_qconfig,
passive_qconfig,
tqt_qconfig,
)
from megengine.quantization.fake_quant import TQT, FakeQuantize
from megengine.quantization.observer import MinMaxObserver, PassiveObserver
from megengine.quantization.quantize import ( from megengine.quantization.quantize import (
_get_quantable_module_names, _get_quantable_module_names,
apply_easy_quant,
disable_fake_quant, disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
propagate_qconfig,
quantize,
quantize_qat, quantize_qat,
reset_qconfig,
) )
class Net(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
class QATNet(Float.Module):
def __init__(self):
super().__init__()
self.quant = QAT.QuantStub()
self.linear = QAT.Linear(3, 3)
self.dequant = QAT.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
def test_propagate_qconfig():
net = QATNet()
propagate_qconfig(net, min_max_fakequant_qconfig)
assert all(
[
net.quant.weight_observer is None,
net.quant.weight_fake_quant is None,
isinstance(net.quant.act_observer, MinMaxObserver),
isinstance(net.quant.act_fake_quant, FakeQuantize),
isinstance(net.linear.weight_observer, MinMaxObserver),
isinstance(net.linear.weight_fake_quant, FakeQuantize),
isinstance(net.linear.act_observer, MinMaxObserver),
isinstance(net.linear.act_fake_quant, FakeQuantize),
net.dequant.weight_observer is None,
net.dequant.weight_fake_quant is None,
net.dequant.act_observer is None,
net.dequant.act_observer is None,
]
)
def init_qat_net():
net = QATNet()
propagate_qconfig(net, min_max_fakequant_qconfig)
min_val = np.random.randint(-127, 0, size=(2,))
max_val = np.random.randint(1, 127, size=(2,))
net.linear.weight_observer.min_val.set_value(min_val[0])
net.linear.weight_observer.max_val.set_value(max_val[0])
net.linear.act_observer.min_val.set_value(min_val[1])
net.linear.act_observer.max_val.set_value(max_val[1])
return net
def test_reset_qconfig():
qat_net = init_qat_net()
new_qat_net = reset_qconfig(qat_net, passive_qconfig)
assert (
new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams()
)
assert (
new_qat_net.linear.get_activation_qparams()
== qat_net.linear.get_activation_qparams()
)
def test_enable_and_disable_observer():
net = init_qat_net()
enable_observer(net)
assert net.quant.act_observer.enabled == True
assert net.linear.weight_observer.enabled == True
assert net.linear.act_observer.enabled == True
disable_observer(net)
assert net.quant.act_observer.enabled == False
assert net.linear.weight_observer.enabled == False
assert net.linear.act_observer.enabled == False
def test_enable_and_disable_fake_quant():
net = init_qat_net()
disable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == False
assert net.linear.weight_fake_quant.enabled == False
assert net.linear.act_fake_quant.enabled == False
enable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == True
assert net.linear.weight_fake_quant.enabled == True
assert net.linear.act_fake_quant.enabled == True
def init_observer(module, data):
enable_observer(module)
disable_fake_quant(module)
module(data)
disable_observer(module)
enable_fake_quant(module)
def test_enable_and_disable_all():
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
init_observer(net, x)
y2 = net(x).numpy()
disable_fake_quant(net)
y3 = net(x).numpy()
enable_fake_quant(net)
y4 = net(x).numpy()
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y2, y4)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y2, y3)
def test_quantize_qat():
net = Net()
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
assert isinstance(qat_net.quant, QAT.QuantStub)
assert isinstance(qat_net.linear, QAT.Linear)
assert isinstance(qat_net.dequant, QAT.DequantStub)
def test_quantize():
qat_net = init_qat_net()
q_net = quantize(qat_net, inplace=False)
assert isinstance(q_net.quant, Q.QuantStub)
assert isinstance(q_net.linear, Q.Linear)
assert isinstance(q_net.dequant, Q.DequantStub)
def test_apply_easy_quant():
qat_net = init_qat_net()
data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
assert isinstance(eq_net.quant.act_observer, PassiveObserver)
assert isinstance(eq_net.linear.weight_observer, PassiveObserver)
assert isinstance(eq_net.linear.act_observer, PassiveObserver)
assert eq_net.dequant.act_observer is None
def test_apply_tqt():
qat_net = init_qat_net()
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
assert isinstance(tqt_net.quant.act_fake_quant, TQT)
assert isinstance(tqt_net.linear.weight_fake_quant, TQT)
assert isinstance(tqt_net.linear.act_fake_quant, TQT)
assert tqt_net.dequant.act_fake_quant is None
def test_get_quantable_module_names(): def test_get_quantable_module_names():
# need to make sure names from Quantized and QAT are the same # need to make sure names from Quantized and QAT are the same
def _get_qat_module_names(): def _get_qat_module_names():
...@@ -87,30 +264,3 @@ def test_convert_with_custom_mapping(): ...@@ -87,30 +264,3 @@ def test_convert_with_custom_mapping():
net = Net() net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample) assert isinstance(qat_net.example, QATExample)
def test_disable_fake_quant():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
y2 = net(x).numpy()
disable_fake_quant(net)
y3 = net(x).numpy()
np.testing.assert_allclose(y1, y3)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y2, y3)