提交 4495c0cc 编写于 作者: M Megvii Engine Team

feat(mgb/quantization): add get quantize parameters support

GitOrigin-RevId: 5727f6356075658691c66a135b073c914969d9c9
上级 9b097859
...@@ -92,6 +92,25 @@ class QATModule(Module): ...@@ -92,6 +92,25 @@ class QATModule(Module):
else: else:
return self.act_observer.get_dtype() return self.act_observer.get_dtype()
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the source of quantization parameters for one tensor.

    ``fake_quant`` is used when it has a ``get_qparams`` attribute;
    otherwise ``observer`` is used when it is not ``None``. When neither
    applies, ``None`` is returned.
    """
    # The fake-quant module takes precedence over the observer.
    if hasattr(fake_quant, "get_qparams"):
        return fake_quant.get_qparams()
    if observer is None:
        return None
    return observer.get_qparams()
def get_weight_qparams(self):
    r"""
    Get weight's quantization parameters.

    Delegates to ``_get_qparams`` with this module's weight fake-quant
    module and weight observer.
    """
    fake_quant = self.weight_fake_quant
    observer = self.weight_observer
    return self._get_qparams(fake_quant, observer)
def get_activation_qparams(self):
    r"""
    Get activation's quantization parameters.

    Delegates to ``_get_qparams`` with this module's activation fake-quant
    module and activation observer.
    """
    fake_quant = self.act_fake_quant
    observer = self.act_observer
    return self._get_qparams(fake_quant, observer)
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_float_module(cls, float_module: Module): def from_float_module(cls, float_module: Module):
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
from .fake_quant import FakeQuantize from .fake_quant import FakeQuantize
from .internal_fake_quant import * from .internal_fake_quant import *
from .observer import HistogramObserver, Observer, ObserverMode from .observer import HistogramObserver, Observer
from .qconfig import ( from .qconfig import (
QConfig, QConfig,
calibration_qconfig, calibration_qconfig,
...@@ -16,3 +16,4 @@ from .qconfig import ( ...@@ -16,3 +16,4 @@ from .qconfig import (
min_max_fakequant_qconfig, min_max_fakequant_qconfig,
tqt_quant_qconfig, tqt_quant_qconfig,
) )
from .utils import QuantMode
...@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype ...@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter from ..core import Buffer, Function, Parameter
from ..jit import sideeffect from ..jit import sideeffect
from ..module import Module from ..module import Module
from .observer import ObserverMode, Round from .observer import Round
from .utils import QuantMode, get_qparam_dict
class _FakeQuantize(Module): class _FakeQuantize(Module):
...@@ -121,8 +122,18 @@ class TQT(_FakeQuantize): ...@@ -121,8 +122,18 @@ class TQT(_FakeQuantize):
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp return inp
def get_qparams(self):
    r"""
    Get this module's quantization parameters.

    Returns a new dict based on the ``QuantMode.TQT`` template, with
    ``"scale"`` set to ``2 ** self.scale``.
    """
    # Copy before writing: get_qparam_dict may hand back a dict object that
    # is shared by every caller, so filling it in place would make all TQT
    # instances return (and overwrite) the same parameters.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict
def get_dtype(self):
    r"""
    Build the quantized dtype from this module's quantization parameters.

    ``scale`` and ``zero_point`` are read from the dict returned by
    ``get_qparams``; each is the first element of the entry's ``numpy()``
    value, or ``None`` when the key is absent from the dict.
    """
    qparams = self.get_qparams()
    scale = qparams["scale"].numpy()[0] if "scale" in qparams else None
    if "zero_point" in qparams:
        zero_point = qparams["zero_point"].numpy()[0]
    else:
        zero_point = None
    return get_quantized_dtype(self.dtype, scale, zero_point)
class FakeQuantize(_FakeQuantize): class FakeQuantize(_FakeQuantize):
...@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): ...@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize):
""" """
def fake_quant_forward(self, inp, q_dict): def fake_quant_forward(self, inp, q_dict):
if q_dict["mode"] == ObserverMode.SYMMERTIC: if q_dict["mode"] == QuantMode.SYMMERTIC:
scale = q_dict["scale"] scale = q_dict["scale"]
# Quant # Quant
oup = Round()(inp / scale) oup = Round()(inp / scale)
......
...@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype ...@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor from ..core import Buffer, Function, tensor
from ..jit import sideeffect from ..jit import sideeffect
from ..module import Module from ..module import Module
from .utils import QuantMode, get_qparam_dict
class Round(Function): class Round(Function):
...@@ -81,29 +82,10 @@ class Observer(Module): ...@@ -81,29 +82,10 @@ class Observer(Module):
pass pass
class ObserverMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2
def create_observer_dict(mode):
if mode == ObserverMode.SYMMERTIC:
return {
"mode": ObserverMode.SYMMERTIC,
"scale": None,
}
else:
return {
"mode": ObserverMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
}
class MinMaxObserver(Observer): class MinMaxObserver(Observer):
def __init__( def __init__(
self, self,
mode=ObserverMode.SYMMERTIC, mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
...@@ -117,10 +99,10 @@ class MinMaxObserver(Observer): ...@@ -117,10 +99,10 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val): def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val) min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val) max_val = F.maximum(0.0, inp_max_val)
q_dict = create_observer_dict(self.mode) q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val q_dict["max_val"] = inp_max_val
if self.mode == ObserverMode.SYMMERTIC: if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val) symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximum to avoid the scale being too small at the beginning # use maximum to avoid the scale being too small at the beginning
q_dict["scale"] = F.maximum( q_dict["scale"] = F.maximum(
...@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__( def __init__(
self, self,
momentum=0.9, momentum=0.9,
mode=ObserverMode.SYMMERTIC, mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
...@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): ...@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver):
self, self,
bins=2048, bins=2048,
upsample_rate=128, upsample_rate=128,
mode=ObserverMode.SYMMERTIC, mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from enum import Enum
from functools import partial, update_wrapper, wraps from functools import partial, update_wrapper, wraps
...@@ -21,3 +22,24 @@ def register_method_to_class(cls): ...@@ -21,3 +22,24 @@ def register_method_to_class(cls):
return func return func
return decorator return decorator
class QuantMode(Enum):
    """Quantization modes used as the ``"mode"`` entry of a qparam dict."""

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Per-mode templates listing the keys a quantization-parameter dict carries.
# These are templates only: callers must obtain a private copy through
# get_qparam_dict() rather than filling these dicts in place.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
}


def get_qparam_dict(mode):
    """Return a fresh quantization-parameter dict for ``mode``.

    The result is a new copy of the template registered in ``qparam_dict``,
    so callers may fill in ``"scale"`` and other keys without affecting
    other callers or the template itself.

    :param mode: a ``QuantMode`` member.
    :return: a new dict holding the template's keys, or ``None`` when
        ``mode`` has no registered template.
    """
    template = qparam_dict.get(mode)
    if template is None:
        return None
    # A shallow copy is enough: template values are None or enum members.
    return dict(template)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册