Commit 1d7dd001 authored by Megvii Engine Team

feat(mge/quantization): add QParams and QuantDtypeMeta for quantization data structure

GitOrigin-RevId: df3416fe13fbff1cc6dd8f88f0a937aa1b6b58a9
Parent 4130dcd3
......@@ -6,9 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from typing import Union
from ..._imperative_rt import OpDef, ops
__all__ = ["OpDef"]
......
......@@ -5,22 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from collections import namedtuple
from typing import Union
import numpy as np
# normal dtype related
from .._imperative_rt import bfloat16, intb1, intb2, intb4
from .._imperative_rt.common import (
bfloat16,
get_scale,
get_zero_point,
intb1,
intb2,
intb4,
is_dtype_equal,
is_quantize,
)
# normal dtype related
def is_lowbit(dtype):
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
......@@ -30,34 +32,80 @@ def is_bfloat16(dtype):
# quantization dtype related
_QuantDtypeMetadata = collections.namedtuple(
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)
_metadata_dict = {
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
"qint32": _QuantDtypeMetadata(
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
# use namedtuple to make class immutable, comparable and easy to print
class QuantDtypeMeta(
namedtuple(
"QuantDtypeMeta",
["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"],
)
):
r"""
Store metadata for a quantized dtype. Can be used to create a custom quant dtype
for QAT when the network does not need to be converted for inference, but only
needs to export network metadata for third-party platform inference.
:param name: a unique name string.
:param cname: used in :func:`~.create_quantized_dtype` for model dump and inference.
:param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``.
:param qmin: an int indicating the quant dtype's lower bound.
:param qmax: an int indicating the quant dtype's upper bound.
:param is_unsigned: a helper value that can be inferred from ``np_dtype_str``.
"""
def __new__(
cls,
name: str,
cname: str,
np_dtype_str: str,
qmin: int,
qmax: int,
is_unsigned: bool = None,
):
assert isinstance(np_dtype_str, str)
is_unsigned = np_dtype_str[0] == "u" if is_unsigned is None else is_unsigned
return super().__new__(cls, name, cname, np_dtype_str, qmin, qmax, is_unsigned)
def __copy__(self):
return self
def __deepcopy__(self, _):
"""
Ignore deepcopy so that a dtype meta can be treated as a singleton, allowing a
stricter identity check in :meth:`~.FakeQuantize.fake_quant_forward`.
"""
return self
_builtin_quant_dtypes = {
"quint8": QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255),
"qint8": QuantDtypeMeta("qint8", "QuantizedS8", "int8", -128, 127),
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
"qint32": QuantDtypeMeta(
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
),
# NOTE: int2 is not supported for model dump yet
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3),
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1),
"quint2": QuantDtypeMeta("quint2", None, "uint8", 0, 3),
"qint2": QuantDtypeMeta("qint2", None, "int8", -2, 1),
}
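For illustration, a custom dtype outside this builtin table behaves as a singleton under deepcopy. A minimal sketch (the 6-bit ``qint6`` below is hypothetical, not part of this commit):

import copy

qint6 = QuantDtypeMeta("qint6", None, "int8", -32, 31)  # cname=None: cannot be dumped
assert not qint6.is_unsigned                  # inferred from np_dtype_str "int8"
assert copy.deepcopy(qint6) is qint6          # __deepcopy__ returns self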
def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta):
qmin = dtype_meta.qmin
qmax = dtype_meta.qmax
if zp < qmin or zp > qmax:
raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
"zero_point should be within [{}, {}] for {}".format(
qmin, qmax, dtype_meta.name
)
)
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
def create_quantized_dtype(
dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
):
r"""
Get quantized dtype with metadata attribute according to ``dtype_meta``.
......@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
not have ``zero_point``, to be consistent with tensors generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.
:param dtype: a string indicating which dtype to return
:param dtype_meta: a QuantDtypeMeta indicating which dtype to return. the
``cname`` attribute cannot be ``None``.
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata = _metadata_dict[dtype_str]
np_dtype_str = metadata.np_dtype_str
is_unsigned = metadata.is_unsigned
if is_unsigned:
if dtype_meta.cname is None:
raise ValueError("dtype {} without cname attr is not supported.")
if dtype_meta.is_unsigned:
if zp is None or int(zp) != zp:
raise ValueError("zero_point should be an integer")
zp = int(zp)
_check_zero_point(zp, dtype_str)
_check_zero_point(zp, dtype_meta)
return np.dtype(
np_dtype_str,
dtype_meta.np_dtype_str,
metadata={
"mgb_dtype": {
"name": metadata.name,
"name": dtype_meta.cname,
"scale": float(scale),
"zero_point": zp,
}
},
)
else:
# Do not merge this with the is_unsigned branch. Metadata should not contain
# invalid fields, to stay consistent with the C dtype.
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}},
dtype_meta.np_dtype_str,
metadata={"mgb_dtype": {"name": dtype_meta.cname, "scale": float(scale)}},
)
......@@ -100,7 +150,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
return get_quantized_dtype("quint8", scale, zero_point)
return create_quantized_dtype(_builtin_quant_dtypes["quint8"], scale, zero_point)
def qint8(scale):
......@@ -108,7 +158,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return get_quantized_dtype("qint8", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None)
def qint32(scale):
......@@ -116,7 +166,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return get_quantized_dtype("qint32", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None)
def quint4(scale, zero_point):
......@@ -125,7 +175,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
return get_quantized_dtype("quint4", scale, zero_point)
return create_quantized_dtype(_builtin_quant_dtypes["quint4"], scale, zero_point)
def qint4(scale):
......@@ -133,42 +183,48 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return get_quantized_dtype("qint4", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = dtype.metadata["mgb_dtype"]
def _convert_to_quantized_dtype(
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta
):
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = metadata.is_unsigned
if is_unsigned:
if (
not is_quantize(dtype)
or dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname
):
raise ValueError("dtype parameter should be a {} dtype".format(dtype_meta))
arr_metadata = dtype.metadata["mgb_dtype"]
if dtype_meta.is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
)
return (
(np.round(arr / scale) + zp)
.clip(metadata.qmin, metadata.qmax)
.clip(dtype_meta.qmin, dtype_meta.qmax)
.astype(dtype)
)
else:
# do not merge with the is_unsigned branch; see ``create_quantized_dtype``
scale = arr_metadata["scale"]
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
return (
np.round(arr / scale).clip(dtype_meta.qmin, dtype_meta.qmax).astype(dtype)
)
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = arr.dtype.metadata["mgb_dtype"]
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta):
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = metadata.is_unsigned
if is_unsigned:
if (
not is_quantize(arr.dtype)
or arr.dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname
):
raise ValueError("arr's dtype should be a {} dtype".format(dtype_meta))
arr_metadata = arr.dtype.metadata["mgb_dtype"]
if dtype_meta.is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
......@@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
"""
return _convert_to_quantized_dtype(arr, q, "quint8")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"])
def convert_from_quint8(arr: np.ndarray):
......@@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "quint8")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"])
def convert_to_qint8(arr: np.ndarray, q: np.dtype):
......@@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_quantized_dtype(arr, q, "qint8")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"])
def convert_from_qint8(arr: np.ndarray):
......@@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint8")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"])
def convert_to_qint32(arr: np.ndarray, q: np.dtype):
......@@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_quantized_dtype(arr, q, "qint32")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"])
def convert_from_qint32(arr):
......@@ -234,7 +290,7 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint32")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"])
def convert_to_quint4(arr: np.ndarray, q: np.dtype):
......@@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
"""
return _convert_to_quantized_dtype(arr, q, "quint4")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"])
def convert_from_quint4(arr: np.ndarray):
......@@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "quint4")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"])
def convert_to_qint4(arr: np.ndarray, q: np.dtype):
......@@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
"""
return _convert_to_quantized_dtype(arr, q, "qint4")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"])
def convert_from_qint4(arr: np.ndarray):
......@@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint4")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])
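A round trip through these helpers, as a minimal sketch with illustrative numbers:

x = np.array([-0.3, 0.0, 0.5], dtype=np.float32)
q = convert_to_quint8(x, quint8(0.01, 128))   # round(x / scale) + zp, clipped to [0, 255]
assert q.tolist() == [98, 128, 178]
r = convert_from_quint8(q)                    # (q - zp) * scale
np.testing.assert_allclose(r, x, atol=0.005)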
......@@ -203,7 +203,7 @@ def conv_transpose2d(
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
if groups != 1:
raise NotImplementedError("TODO")
raise NotImplementedError("group transposed conv2d is not supported yet.")
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......
......@@ -13,7 +13,6 @@ import itertools
import json
import os
import typing
import warnings
import weakref
import numpy as np
......
......@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
......@@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta):
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
warnings.warn(
logger.warning(
"Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here",
DeprecationWarning,
"so requires_grad argument is ignored here"
)
def predicate(obj) -> bool:
......@@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta):
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
warnings.warn(
logger.warning(
"Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here",
DeprecationWarning,
"so requires_grad argument is ignored here"
)
def predicate(obj) -> bool:
......
......@@ -7,7 +7,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from ...quantization import FakeQuantize, Observer, QConfig
# avoid circular reference
from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig
from ...tensor import Tensor
from ..module import Module
......@@ -73,19 +76,19 @@ class QATModule(Module):
# do observer
if observer is None:
oup = target
q_dict = None
qparams = None
else:
oup = observer(target)
q_dict = observer.get_qparams()
qparams = observer.get_qparams()
# do fake quant
if fake_quant is not None:
oup = fake_quant(oup, q_dict)
oup = fake_quant(oup, qparams)
# use fake_quant's qparams if available.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
qparams = fake_quant.get_qparams()
# write back to the tensor's qparams.
if q_dict is not None:
oup.q_dict.update(q_dict)
if qparams is not None:
oup.qparams.update(qparams)
return oup
def apply_quant_weight(self, target: Tensor):
......@@ -118,7 +121,7 @@ class QATModule(Module):
Get weight's quantization dtype as the method from ``qconfig``.
"""
return self._get_method_result(
"get_dtype", self.weight_fake_quant, self.weight_observer
"get_quantized_dtype", self.weight_fake_quant, self.weight_observer
)
def get_activation_dtype(self):
......@@ -126,7 +129,7 @@ class QATModule(Module):
Get activation's quantization dtype as the method from ``qconfig``.
"""
return self._get_method_result(
"get_dtype", self.act_fake_quant, self.act_observer
"get_quantized_dtype", self.act_fake_quant, self.act_observer
)
def get_weight_qparams(self):
......
......@@ -7,8 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .internal_fake_quant import *
from .observer import HistogramObserver, Observer
from .observer import Observer
from .qconfig import (
QConfig,
calibration_qconfig,
......@@ -20,4 +19,15 @@ from .qconfig import (
sync_ema_fakequant_qconfig,
tqt_qconfig,
)
from .utils import QuantMode
from .quantize import (
apply_easy_quant,
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
propagate_qconfig,
quantize,
quantize_qat,
reset_qconfig,
)
from .utils import QParams, QuantMode, create_qparams
......@@ -6,40 +6,48 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Union
from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward
from ..tensor import Parameter
from .utils import (
QParams,
QParamsModuleMixin,
QuantMode,
create_qparams,
fake_quant_tensor,
tqt_forward,
)
logger = get_logger(__name__)
class _FakeQuantize(Module):
r"""
A Basic Fake Quant module.
:param dtype: a string indicating the target quantization type of input.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
class _FakeQuantize(Module):
def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
super().__init__()
if not dtype in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
if isinstance(dtype, str):
if not dtype in _builtin_quant_dtypes:
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _builtin_quant_dtypes.keys()
)
)
dtype = _builtin_quant_dtypes[dtype]
if "narrow_range" in kwargs:
del kwargs["narrow_range"]
logger.warning(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here",
exc_info=DeprecationWarning,
)
self.dtype = dtype
self.narrow_range = narrow_range
self.qmin = (
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
)
self.qmax = _metadata_dict[dtype].qmax
self.qmin = dtype.qmin
self.qmax = dtype.qmax
self.enabled = enable
def enable(self):
......@@ -48,61 +56,64 @@ class _FakeQuantize(Module):
def disable(self):
self.enabled = False
def fake_quant_forward(self, inp, q_dict=None):
return inp
def fake_quant_forward(self, inp, qparams: QParams = None):
raise NotImplementedError
def normal_forward(self, inp, q_dict=None):
def normal_forward(self, inp, qparams: QParams = None):
return inp
def forward(self, inp, q_dict=None):
def forward(self, inp, qparams: QParams = None):
if self.enabled:
return self.fake_quant_forward(inp, q_dict=q_dict)
return self.fake_quant_forward(inp, qparams=qparams)
else:
return self.normal_forward(inp, q_dict=q_dict)
return self.normal_forward(inp, qparams=qparams)
class TQT(_FakeQuantize):
class TQT(_FakeQuantize, QParamsModuleMixin):
r"""
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
super().__init__(dtype, narrow_range, enable, **kwargs)
super().__init__(dtype, enable, **kwargs)
self.scale = Parameter(0.0, dtype="float32")
def fake_quant_forward(self, inp, q_dict=None):
def fake_quant_forward(self, inp, qparams: QParams = None):
# when enabled, TQT does fake-quant forward and finetunes the scale
return tqt_forward(self.qmin, self.qmax, inp, self.scale)
def get_qparams(self):
q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
q_dict["scale"] = 2 ** self.scale.detach()
return q_dict
def set_qparams(self, q_dict):
def set_qparams(self, qparams: QParams):
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
qparams.mode == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
self.scale._reset(F.log(q_dict["scale"]) / math.log(2))
self.scale[...] = F.log(qparams.scale) / math.log(2)
def get_dtype(self):
q_dict = self.get_qparams()
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, scale, zero_point)
def get_qparams(self):
return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)
class FakeQuantize(_FakeQuantize):
r"""
A module that does quantize and dequantize according to the observer's scale and zero_point.
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
def fake_quant_forward(self, inp, q_dict=None):
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
def fake_quant_forward(self, inp, qparams: QParams = None):
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
return fake_quant_tensor(inp, qparams)
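A minimal usage sketch with hypothetical values; ``create_qparams`` resolves the string "qint8" to the same builtin singleton that ``FakeQuantize`` stores, so the ``is`` check above passes:

import megengine as mge
from megengine.quantization import FakeQuantize, QuantMode, create_qparams

fq = FakeQuantize("qint8")
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor(0.1))
out = fq(mge.tensor([0.04, 0.26]), qparams)   # snaps to multiples of 0.1 -> [0.0, 0.3]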
......@@ -16,4 +16,6 @@ from ..autodiff import Function
from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver
from .qconfig import QConfig
from .utils import QParams
......@@ -8,51 +8,51 @@
import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union
import numpy as np
from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, get_qparam_dict
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
logger = get_logger(__name__)
class Observer(Module):
class Observer(Module, QParamsModuleMixin):
r"""
A base class for Observer Module.
:param dtype: a string indicating to collect scale and zero_point of which dtype.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
"""
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
if isinstance(dtype, str):
if not dtype in _builtin_quant_dtypes:
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _builtin_quant_dtypes.keys()
)
)
dtype = _builtin_quant_dtypes[dtype]
if "narrow_range" in kwargs:
del kwargs["narrow_range"]
logger.warning(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here",
exc_info=DeprecationWarning,
)
self.dtype = dtype
self.narrow_range = narrow_range
self.qmin = (
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
)
self.qmax = _metadata_dict[dtype].qmax
self.qmin = dtype.qmin
self.qmax = dtype.qmax
self.enabled = True
def get_dtype(self):
q_dict = self.get_qparams()
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
numpy_zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
def enable(self):
self.enabled = True
......@@ -70,21 +70,16 @@ class Observer(Module):
def forward(self, x):
pass
@abstractmethod
def get_qparams(self, **kwargs):
pass
class MinMaxObserver(Observer):
def __init__(
self,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(dtype, narrow_range, **kwargs)
super().__init__(dtype, **kwargs)
self.mode = mode
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
......@@ -93,26 +88,22 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximum to avoid the scale being too small at the beginning
q_dict["scale"] = F.maximum(
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
# zero_point = self.zero_point
zero_point = None
else:
# use maximum to avoid the scale being too small at the beginning
q_dict["scale"] = F.maximum(
scale = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"])
zero_point = self.qmin - F.round((min_val / scale))
return q_dict
return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)
def get_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
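For intuition, with an observed range of [-2.0, 6.0]: the symmetric branch (qint8, qmin=-128, qmax=127) gives scale = max(2.0, 6.0) / 127.5 ≈ 0.047 and no zero_point, while the asymmetric branch (quint8, qmin=0, qmax=255) gives scale = (6.0 - (-2.0)) / 255 ≈ 0.031 and zero_point = 0 - round(-2.0 / 0.031) = 64.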
......@@ -122,8 +113,8 @@ class MinMaxObserver(Observer):
# stop gradient
x = x_orig.detach()
# find max and min
self.min_val._reset(F.minimum(self.min_val, x.min()))
self.max_val._reset(F.maximum(self.max_val, x.max()))
self.min_val[...] = F.minimum(self.min_val, x.min())
self.max_val[...] = F.maximum(self.max_val, x.max())
return x_orig
......@@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver):
else:
min_x = x.min()
max_x = x.max()
self.min_val._reset(F.minimum(self.min_val, min_x))
self.max_val._reset(F.maximum(self.max_val, max_x))
self.min_val[...] = F.minimum(self.min_val, min_x)
self.max_val[...] = F.maximum(self.max_val, max_x)
return x_orig
class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(
self,
momentum=0.9,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
momentum: float = 0.9,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
super().__init__(mode, eps, dtype, **kwargs)
self.momentum = Tensor(momentum, dtype="float32")
# used to avoid if-clauses in the first forward which is not supported
# in trace mode.
self.runtime_momentum = Tensor(0.0)
def set_momentum(self, momentum):
self.momentum = Tenosr(momentum, dtype="float32")
self.momentum = Tensor(momentum, dtype="float32")
def forward(self, x_orig):
if self.enabled:
# stop gradient
x = x_orig.detach()
# Exponential Moving Average
self.min_val._reset(
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min()
)
self.max_val._reset(
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max()
)
self.runtime_momentum = self.momentum
self.runtime_momentum[...] = self.momentum
return x_orig
......@@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
else:
min_x = x.min()
max_x = x.max()
self.min_val._reset(
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * min_x
)
self.max_val._reset(
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * max_x
)
self.runtime_momentum = self.momentum
self.runtime_momentum[...] = self.momentum
return x_orig
class HistogramObserver(MinMaxObserver):
def __init__(
self,
bins=2048,
upsample_rate=128,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
bins: int = 2048,
upsample_rate: int = 128,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
super().__init__(mode, eps, dtype, **kwargs)
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.dst_nbins = (
_builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1
)
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
def _non_linear_param_search(self):
......@@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver):
class PassiveObserver(Observer):
r"""
This class can be set :attr:`scale` derectly.
An Observer that supports setting :attr:`scale` directly.
"""
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__(dtype, narrow_range, **kwargs)
self.q_dict = None
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__(dtype, **kwargs)
self.qparams = None
self.orig_scale = None
@property
def scale(self):
return self.q_dict["scale"]
return self.qparams.scale
@scale.setter
def scale(self, value):
assert value > 0
self.q_dict["scale"][...] = Tensor(value)
def scale(self, value: np.ndarray):
assert np.all(value > 0)
self.qparams.scale[...] = Tensor(value)
def get_qparams(self):
return self.q_dict
return self.qparams
def set_qparams(self, q_dict):
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
def set_qparams(self, qparams: QParams):
"""
:param qparams: used to set initial scale.
"""
self.qparams = deepcopy(qparams)
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
if qparams.dtype_meta is None:
qparams.dtype_meta = self.dtype
else:
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
self.orig_scale = qparams.scale.numpy()
def forward(self, x):
r"""
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`.
"""
return x
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import namedtuple
from functools import partial
from ..module import Module
......@@ -19,7 +20,13 @@ from .observer import (
)
class QConfig:
# use namedtuple to make class immutable, comparable and easy to print
class QConfig(
namedtuple(
"QConfig",
["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
)
):
r"""
A config class indicating how to quantize the ``activation`` and ``weight`` of a
:class:`~.QATModule`. See :meth:`~.QATModule.set_qconfig` for detailed usage.
......@@ -37,90 +44,66 @@ class QConfig:
# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
Each parameter is a ``class`` rather than an instance, and we recommend using ``functools.partial``
to bind initialization parameters of the ``class``, so that no extra parameters need to be provided in
:meth:`~.QATModule.set_qconfig`.
Usually we set ``narrow_range`` of weight-related parameters to ``True`` and of activation-related
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if
all four variables are -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so they need the narrowed range.
Usually we choose a narrow-version dtype (like ``qint8_narrow``) for weight-related
parameters and the normal version for activation-related ones. For the result of
multiplication and addition as ``a * b + c * d``, if all four variables are -128 of
dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so qmin needs to be narrowed to -127.
"""
def __init__(
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant
):
def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
" partial(MyObserver, x=1) to override arguments to constructor if needed"
)
self.weight_observer = weight_observer
self.act_observer = act_observer
self.weight_fake_quant = weight_fake_quant
self.act_fake_quant = act_fake_quant
def __eq__(self, other):
def eq(a, b):
if isinstance(a, partial) and isinstance(b, partial):
return all(
[a.func == b.func, a.args == b.args, a.keywords == b.keywords]
)
else:
return a == b
return (
eq(self.weight_observer, other.weight_observer)
and eq(self.act_observer, other.act_observer)
and eq(self.weight_fake_quant, other.weight_fake_quant)
and eq(self.act_fake_quant, other.act_fake_quant)
return super().__new__(
cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant
)
min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(MinMaxObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
ema_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
sync_ema_fakequant_qconfig = QConfig(
weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(
SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
ema_lowbit_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint4"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"),
weight_fake_quant=partial(FakeQuantize, dtype="qint4"),
act_fake_quant=partial(FakeQuantize, dtype="qint4"),
)
calibration_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(HistogramObserver, dtype="qint8"),
weight_fake_quant=None,
act_fake_quant=None,
)
......@@ -128,15 +111,15 @@ calibration_qconfig = QConfig(
tqt_qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(TQT, dtype="qint8_narrow"),
act_fake_quant=partial(TQT, dtype="qint8"),
)
passive_qconfig = QConfig(
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(PassiveObserver, dtype="qint8_narrow"),
act_observer=partial(PassiveObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
easyquant_qconfig = passive_qconfig
......@@ -18,6 +18,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from ..tensor import Tensor
from .qconfig import QConfig, ema_fakequant_qconfig
......@@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
if not inplace:
module = deepcopy(module)
def safe_call(func, q_dict):
def safe_call(func, qparams):
inst = func() if func is not None else None
if inst is not None and getattr(inst, "set_qparams", None) is not None:
inst.set_qparams(q_dict)
inst.set_qparams(qparams)
return inst
def is_qat(mod: Module):
......@@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
weight_params = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_params)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params)
if m.with_act:
act_q_dict = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
act_params = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_params)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params)
return module
......@@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable):
return hooks
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
def apply_easy_quant(
module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40
):
r"""
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
Search for optimal scales.
......
......@@ -5,9 +5,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Dict
from typing import Union
import numpy as np
......@@ -15,7 +16,11 @@ from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import _metadata_dict
from ..core.tensor.dtype import (
QuantDtypeMeta,
_builtin_quant_dtypes,
create_quantized_dtype,
)
from ..tensor import Tensor
......@@ -61,37 +66,100 @@ class QuantMode(Enum):
ASYMMERTIC = 2
qparam_dict = {
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
QuantMode.ASYMMERTIC: {
"mode": QuantMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
},
class QParams:
"""
Standardizes the qparams format for FakeQuant, Observer and Tensor. If custom
qparams are needed, inherit this class and add custom ``__slots__``.
"""
__slots__ = "mode", "dtype_meta", "scale", "zero_point"
def __init__(
self,
mode: QuantMode,
dtype_meta: QuantDtypeMeta,
scale: Tensor,
zero_point: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
def update(self, qparams: "QParams"):
for key in self.__slots__:
setattr(self, key, getattr(qparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "QParams({})".format(content)
class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self):
qparams = self.get_qparams()
dtype = qparams.dtype_meta
scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
zero_point = (
int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
)
return create_quantized_dtype(dtype, scale, zero_point)
@abc.abstractmethod
def get_qparams(self) -> QParams:
pass
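A minimal sketch of a mixin user (the ``ConstScaleObserver`` below is hypothetical): only ``get_qparams`` must be implemented, and ``get_quantized_dtype`` derives the numpy dtype from it:

class ConstScaleObserver(QParamsModuleMixin):
    """Toy module that always reports a fixed symmetric qint8 scale."""

    def get_qparams(self) -> QParams:
        return QParams(
            QuantMode.SYMMERTIC, _builtin_quant_dtypes["qint8"], Tensor(0.05), None
        )

dt = ConstScaleObserver().get_quantized_dtype()   # np.dtype("int8") with scale 0.05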
_builtin_qparams = {
QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}
def get_qparam_dict(mode: QuantMode):
def create_qparams(
mode: QuantMode = QuantMode.SYMMERTIC,
dtype_meta: Union[str, QuantDtypeMeta] = None,
scale: Tensor = None,
zero_point: Tensor = None,
):
"""
Return the quantization parameters dictionary according to the mode.
Return :class:`~.QParams` according to the mode.
"""
return qparam_dict.get(mode, None)
if isinstance(dtype_meta, str):
dtype_meta = _builtin_quant_dtypes[dtype_meta]
if mode is None:
return QParams(mode, dtype_meta, scale, zero_point)
assert isinstance(mode, QuantMode)
return _builtin_qparams[mode](
dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
)
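Usage sketch: string names resolve to the builtin singletons, so identity checks pass, and calling it with no arguments yields the empty placeholder used by ``Tensor.qparams``:

qp = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor(0.1))
assert qp.dtype_meta is _builtin_quant_dtypes["qint8"]
assert qp.zero_point is None
qp2 = create_qparams()                        # mode defaults to SYMMERTIC, rest None
assert qp2.scale is None and qp2.dtype_meta is None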
def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
"""
Apply fake quantization to the inp tensor.
:param inp: the input tensor which needs to be faked.
:param qmin: the minimum value to which the integer is clipped.
:param qmax: the maximum value to which the integer is clipped.
:param q_dict: the quantization parameter dict.
:param qparams: a :class:`~.QParams` providing the mode, dtype bounds, scale and zero_point.
"""
scale = q_dict["scale"]
zero_point = Tensor([0.0], dtype=np.float32)
if q_dict["mode"] == QuantMode.ASYMMERTIC:
zero_point = q_dict["zero_point"]
scale = qparams.scale
if qparams.mode == QuantMode.ASYMMERTIC:
zero_point = qparams.zero_point
else:
zero_point = Tensor([0.0], dtype=np.float32)
qmin = qparams.dtype_meta.qmin
qmax = qparams.dtype_meta.qmax
op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
return apply(op, inp, scale, zero_point)[0]
......@@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
:param bias: the bias tensor which needs to be faked.
:param inp: the input tensor which contains the quantization parameters.
:param qmax: the weight tensor which contain the quantization parameters.
:param w_qat: the weight tensor which contains the quantization parameters.
.. warning::
Only works for the symmetric quantization method now.
"""
b_qat = bias
if hasattr(inp, "q_dict") and b_qat is not None:
if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
# use the same mode with weight.
b_dict = get_qparam_dict(w_qat.q_dict["mode"])
b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
# TODO: add zero_point for ASYMMERTIC mode.
qmax = _metadata_dict["qint32"].qmax
qmin = _metadata_dict["qint32"].qmin
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
b_qat.q_dict.update(b_dict)
if (
getattr(inp, "qparams", None) is not None
and getattr(w_qat, "qparams", None) is not None
and bias is not None
):
inp_params = inp.qparams
w_params = w_qat.qparams
if inp_params.scale is not None and w_params.scale is not None:
assert inp_params.mode == w_params.mode, "incompatible QuantMode"
# TODO: support quint8 dtype.
assert (
inp_params.dtype_meta.np_dtype_str == "int8"
and w_params.dtype_meta.np_dtype_str == "int8"
), "fake_quant_bias only support int8 like dtype now"
# use the same mode with weight.
# TODO: avoid hardcode
b_dtype = _builtin_quant_dtypes["qint32"]
b_param = create_qparams(
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
)
b_qat = fake_quant_tensor(bias, b_param)
b_qat.qparams.update(b_param)
return b_qat
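Note the scale algebra here: for a layer computing ``inp * w + b``, each product term carries the scale ``inp_scale * w_scale``, so the bias must be quantized with that same scale; fake-quantizing it against qint32 bounds mirrors how the backend stores bias values.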
......@@ -22,6 +22,8 @@ from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import auto_naming
logger = get_logger(__name__)
class Tensor(_Tensor, ArrayMethodMixin):
r"""
......@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
grad = None
dmap_callback = None
_q_dict = None
_qparams = None
def __new__(
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
......@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
if isinstance(data, _Tensor):
if dtype is not None:
get_logger().warning(
logger.warning(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj = _Tensor.__new__(cls, data)
......@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin):
return super().dtype
@property
def q_dict(self):
if self._q_dict is None:
self._q_dict = {"mode": None, "scale": None, "zero_point": None}
return self._q_dict
def qparams(self):
from .quantization.utils import create_qparams # pylint: disable=all
if self._qparams is None:
self._qparams = create_qparams()
return self._qparams
def numpy(self) -> np.ndarray:
r"""
......@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin):
def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy
"""
state = {
"qdict": self.q_dict,
"numpy": self.numpy(),
"dtype": self.dtype,
"device": self.device.logical_name,
}
if self._qparams is not None:
state["qparams"] = self._qparams
return state
def __setstate__(self, state):
self._q_dict = state.pop("qdict")
from .quantization.utils import create_qparams # pylint: disable=all
if "qdict" in state:
qparams = state.pop("qdict")
logger.warning(
"Tensor's 'qdict' state is depreciated. Use 'qparams' instead"
)
elif "qparams" in state:
qparams = state.pop("qparams")
else:
qparams = None
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device")))
self._qparams = qparams
tensor = Tensor
......
......@@ -14,7 +14,7 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops
from megengine.core.tensor.dtype import (
_metadata_dict,
_builtin_quant_dtypes,
convert_from_qint4,
convert_from_qint8,
convert_from_quint4,
......@@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None):
def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
metadata = _metadata_dict[dtype_str]
metadata = _builtin_quant_dtypes[dtype_str]
assert "mgb_dtype" in oup.dtype.metadata
assert is_quantize(oup.dtype)
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name)
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.cname)
np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype))
if is_unsigned:
np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype))
......
......@@ -65,9 +65,9 @@ def test_tensor_serialization():
with TemporaryFile() as f:
a = Tensor(0)
a.q_dict["scale"] = Tensor(1.0)
a.qparams.scale = Tensor(1.0)
pickle.dump(a, f)
f.seek(0)
b = pickle.load(f)
assert isinstance(b.q_dict["scale"], Tensor)
np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0)
assert isinstance(b.qparams.scale, Tensor)
np.testing.assert_equal(b.qparams.scale.numpy(), 1.0)
......@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import numpy as np
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
......@@ -86,3 +88,23 @@ def test_as_type():
b = a.astype(quint8(0.3, 128))
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3)
np.testing.assert_equal(get_zero_point(b.dtype), 128)
def test_qparams():
x = Tensor(1)
assert x.qparams.scale is None
x.qparams.scale = Tensor(1.0)
assert x.qparams.scale.numpy() == 1.0
x2 = copy.copy(x)
assert x.qparams is x2.qparams and x2.qparams.scale.numpy() == 1.0
x3 = copy.deepcopy(x)
assert x.qparams is not x3.qparams and x3.qparams.scale.numpy() == 1.0
def test_name():
x = Tensor(0)
assert x.name == ""
x.name = "x"
assert x.name == "x"
x = Tensor(0, name="x")
assert x.name == "x"
......@@ -406,28 +406,3 @@ def test_copy_d2h():
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
def test_name():
x = tensor(0)
assert x.name == ""
x.name = "x"
assert x.name == "x"
x = tensor(0, name="x")
assert x.name == "x"
def test_q_dict():
x = tensor(1)
assert x.q_dict["scale"] is None
x.q_dict["scale"] = tensor(1.0)
y = tensor(1)
assert y.q_dict["scale"] is None
y.q_dict["scale"] = tensor(2.0)
assert x.q_dict["scale"].numpy() == 1.0
assert y.q_dict["scale"].numpy() == 2.0
z = x + y
assert z.q_dict["scale"] is None
......@@ -12,9 +12,15 @@ import pytest
import megengine as mge
from megengine import tensor
from megengine.core.autodiff.grad import Function, Grad
from megengine.core.tensor.dtype import QuantDtypeMeta
from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.internal_fake_quant import *
from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward
from megengine.quantization.utils import (
QuantMode,
create_qparams,
fake_quant_tensor,
tqt_forward,
)
class TQT_numpy:
......@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
def test_fakequant():
qmin = -126
qmax = 129
test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax)
def run(zero_point, scale):
q_dict = {}
q_dict["mode"] = QuantMode.ASYMMERTIC
q_dict["scale"] = scale
q_dict["zero_point"] = zero_point
qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point)
inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32))
inp = tensor(inp_data, dtype=np.float32)
# test forward
oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy()
oup = fake_quant_tensor(inp, qparams).numpy()
oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy()
assert np.allclose(oup, oup_gt)
assert oup.shape == oup_gt.shape
......@@ -128,7 +132,7 @@ def test_fakequant():
# test backward
x = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qmin, qmax, q_dict)
y = fake_quant_tensor(x, qparams)
grad(y, tensor(F.ones_like(x)))
x1 = tensor(inp_data, dtype=np.float32)
......
......@@ -10,7 +10,13 @@ import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine import Parameter, Tensor
from megengine.core.tensor import dtype
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig
from megengine.quantization import (
FakeQuantize,
MinMaxObserver,
QConfig,
QuantMode,
create_qparams,
)
from megengine.quantization.quantize import (
disable_fake_quant,
disable_observer,
......@@ -18,10 +24,10 @@ from megengine.quantization.quantize import (
)
min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(MinMaxObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
inp_scale = np.float32(np.random.rand() + 1)
......@@ -111,7 +117,7 @@ def test_dequant_stub():
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.scale = inp_scale
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
......@@ -146,12 +152,12 @@ def test_elemwise(kind):
x1_scale = np.float32(np.random.rand() + 1)
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1 = fake_quant_act(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x1.qparams.scale = x1_scale
x2_scale = np.float32(np.random.rand() + 1)
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x2 = fake_quant_act(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x2.qparams.scale = x2_scale
x1_int8 = quant(x1, x1_scale)
x2_int8 = quant(x2, x2_scale)
......@@ -187,7 +193,7 @@ def test_linear():
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
x_int8 = quant(x, inp_scale)
......@@ -230,7 +236,7 @@ def test_conv(module):
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
x_int8 = quant(x, inp_scale)
......
......@@ -6,6 +6,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
......@@ -56,14 +57,14 @@ def test_histogram_observer():
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
m = PassiveObserver("qint8")
m.set_qparams(q_dict)
m.set_qparams(qparams)
assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0
assert m.scale == 2.0
assert m.get_qparams() == {"scale": mge.tensor(2.0)}
assert m.scale.numpy() == 1.0
assert m.get_qparams().dtype_meta == qparams.dtype_meta
assert m.get_qparams().scale == qparams.scale
assert m.get_qparams() == qparams
@pytest.mark.require_ngpu(2)
......
......@@ -6,6 +6,7 @@ import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
from megengine.quantization import QuantMode, create_qparams
def quant(x, scale):
......@@ -26,13 +27,13 @@ def test_elemwise(kind):
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1_scale = np.float32(np.random.rand() + 1)
x1 = fake_quant(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale))
x1_int8 = quant(x1, x1_scale)
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x2_scale = np.float32(np.random.rand() + 1)
x2 = fake_quant(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale))
x2_int8 = quant(x2, x2_scale)
output_scale = np.float32(np.random.rand() + 1)
......
from functools import partial
from megengine.quantization import QConfig, tqt_qconfig
from megengine.quantization.fake_quant import TQT
def test_equal():
qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
assert qconfig == tqt_qconfig
......@@ -33,7 +33,7 @@ from megengine.quantization.quantize import (
)
class Net(Float.Module):
class FloatNet(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
......@@ -113,25 +113,25 @@ def test_reset_qconfig():
def test_enable_and_disable_observer():
net = init_qat_net()
enable_observer(net)
assert net.quant.act_observer.enabled == True
assert net.linear.weight_observer.enabled == True
assert net.linear.act_observer.enabled == True
assert net.quant.act_observer.enabled is True
assert net.linear.weight_observer.enabled is True
assert net.linear.act_observer.enabled is True
disable_observer(net)
assert net.quant.act_observer.enabled == False
assert net.linear.weight_observer.enabled == False
assert net.linear.act_observer.enabled == False
assert net.quant.act_observer.enabled is False
assert net.linear.weight_observer.enabled is False
assert net.linear.act_observer.enabled is False
def test_enable_and_disable_fake_quant():
net = init_qat_net()
disable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == False
assert net.linear.weight_fake_quant.enabled == False
assert net.linear.act_fake_quant.enabled == False
assert net.quant.act_fake_quant.enabled is False
assert net.linear.weight_fake_quant.enabled is False
assert net.linear.act_fake_quant.enabled is False
enable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == True
assert net.linear.weight_fake_quant.enabled == True
assert net.linear.act_fake_quant.enabled == True
assert net.quant.act_fake_quant.enabled is True
assert net.linear.weight_fake_quant.enabled is True
assert net.linear.act_fake_quant.enabled is True
def init_observer(module, data):
......@@ -144,7 +144,7 @@ def init_observer(module, data):
def test_enable_and_disable_all():
x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
net = FloatNet()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
......@@ -162,7 +162,7 @@ def test_enable_and_disable_all():
def test_quantize_qat():
net = Net()
net = FloatNet()
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
assert isinstance(qat_net.quant, QAT.QuantStub)
assert isinstance(qat_net.linear, QAT.Linear)
......