提交 4495c0cc 编写于 作者: M Megvii Engine Team

feat(mgb/quantization): add get quantize parameters support

GitOrigin-RevId: 5727f6356075658691c66a135b073c914969d9c9
上级 9b097859
......@@ -92,6 +92,25 @@ class QATModule(Module):
else:
return self.act_observer.get_dtype()
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
if hasattr(fake_quant, "get_qparams"):
return fake_quant.get_qparams()
elif observer is not None:
return observer.get_qparams()
return None
def get_weight_qparams(self):
r"""
Get weight's quantization parameters.
"""
return self._get_qparams(self.weight_fake_quant, self.weight_observer)
def get_activation_qparams(self):
r"""
Get activation's quantization parameters.
"""
return self._get_qparams(self.act_fake_quant, self.act_observer)
@classmethod
@abstractmethod
def from_float_module(cls, float_module: Module):
......
......@@ -8,7 +8,7 @@
from .fake_quant import FakeQuantize
from .internal_fake_quant import *
from .observer import HistogramObserver, Observer, ObserverMode
from .observer import HistogramObserver, Observer
from .qconfig import (
QConfig,
calibration_qconfig,
......@@ -16,3 +16,4 @@ from .qconfig import (
min_max_fakequant_qconfig,
tqt_quant_qconfig,
)
from .utils import QuantMode
......@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import ObserverMode, Round
from .observer import Round
from .utils import QuantMode, get_qparam_dict
class _FakeQuantize(Module):
......@@ -121,8 +122,18 @@ class TQT(_FakeQuantize):
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp
def get_qparams(self):
qdict = get_qparam_dict(QuantMode.TQT)
qdict["scale"] = 2 ** self.scale
return qdict
def get_dtype(self):
return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
q_dict = self.get_qparams()
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
)
return get_quantized_dtype(self.dtype, scale, zero_point)
class FakeQuantize(_FakeQuantize):
......@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize):
"""
def fake_quant_forward(self, inp, q_dict):
if q_dict["mode"] == ObserverMode.SYMMERTIC:
if q_dict["mode"] == QuantMode.SYMMERTIC:
scale = q_dict["scale"]
# Quant
oup = Round()(inp / scale)
......
......@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor
from ..jit import sideeffect
from ..module import Module
from .utils import QuantMode, get_qparam_dict
class Round(Function):
......@@ -81,29 +82,10 @@ class Observer(Module):
pass
class ObserverMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2
def create_observer_dict(mode):
if mode == ObserverMode.SYMMERTIC:
return {
"mode": ObserverMode.SYMMERTIC,
"scale": None,
}
else:
return {
"mode": ObserverMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
}
class MinMaxObserver(Observer):
def __init__(
self,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
......@@ -117,10 +99,10 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
q_dict = create_observer_dict(self.mode)
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
if self.mode == ObserverMode.SYMMERTIC:
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin
q_dict["scale"] = F.maximum(
......@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(
self,
momentum=0.9,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
......@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver):
self,
bins=2048,
upsample_rate=128,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
......
......@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from enum import Enum
from functools import partial, update_wrapper, wraps
......@@ -21,3 +22,24 @@ def register_method_to_class(cls):
return func
return decorator
class QuantMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2
TQT = 3
qparam_dict = {
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,},
QuantMode.ASYMMERTIC: {
"mode": QuantMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
},
QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
}
def get_qparam_dict(mode):
return qparam_dict.get(mode, None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册