提交 d1be3127 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix tqt load and convert issue and observer calculate params issue

GitOrigin-RevId: f8511f72adbac3869f7bb05f3f1364329798119e
上级 30d6b4f6
......@@ -6,12 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Iterable
import numpy as np
from .. import functional as F
from ..autodiff import Function
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..module import Module
from ..tensor import Parameter, Tensor
......@@ -72,20 +68,10 @@ class TQT(_FakeQuantize):
"""
def __init__(
    self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
):
    """Create a TQT fake-quantize module.

    All arguments are forwarded to the base fake-quantize class.
    The learnable scale is initialized to a placeholder here; a real
    value must be injected later through :meth:`set_qparams` before use.
    """
    super().__init__(dtype, narrow_range, enable, **kwargs)
    # Stored as log2 of the quantization scale (see set_qparams /
    # get_qparams); starts at 0.0 until set_qparams overwrites it.
    self.scale = Parameter(0.0, dtype="float32")
def fake_quant_forward(self, inp, q_dict=None):
# when enable, TQT will do fakequant forward, finetune the scale
......@@ -93,14 +79,22 @@ class TQT(_FakeQuantize):
def get_qparams(self):
    """Return a symmetric-mode qparam dict for this TQT module.

    The exported scale is ``2 ** self.scale`` (the parameter holds the
    log2 of the scale). ``detach()`` stops gradients from flowing
    through the exported value.
    """
    q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
    # Single assignment; the pre-detach write in the original was dead code.
    q_dict["scale"] = 2 ** self.scale.detach()
    return q_dict
def set_qparams(self, q_dict):
    """Initialize the learnable scale from *q_dict*.

    Only symmetric quantization is supported; ``q_dict`` must carry an
    initialized ``"scale"`` tensor, which is stored as log2(scale).
    """
    mode = q_dict["mode"]
    assert mode == QuantMode.SYMMERTIC, "only symmetric quantization is supported by TQT"
    scale = q_dict.get("scale")
    if scale is None:
        raise AssertionError("Can not get an initialized scale")
    # Store in log2 domain so get_qparams can recover it as 2 ** scale.
    self.scale._reset(F.log(scale) / math.log(2))
def get_dtype(self):
    """Build the concrete quantized dtype from the current qparams.

    Missing entries map to ``None``; present entries are converted from
    tensors to numpy values before being handed to
    ``get_quantized_dtype``.
    """
    q_dict = self.get_qparams()
    # NOTE(review): assumes "scale"/"zero_point" are scalar tensors whose
    # .numpy() form is accepted downstream — confirm against callers.
    scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
    zero_point = (
        None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
    )
    return get_quantized_dtype(self.dtype, scale, zero_point)
......
......@@ -17,7 +17,7 @@ from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, Round, get_qparam_dict
from .utils import QuantMode, get_qparam_dict
class Observer(Module):
......@@ -110,7 +110,7 @@ class MinMaxObserver(Observer):
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"])
return q_dict
......@@ -453,12 +453,10 @@ class PassiveObserver(Observer):
This class can be set :attr:`scale` derectly.
"""
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
    """Create a passive observer whose qparams are supplied externally.

    ``q_dict`` and ``orig_scale`` start as ``None`` and are filled in
    later by :meth:`set_qparams` (e.g. during qconfig reset).
    """
    super().__init__(dtype, narrow_range, **kwargs)
    self.q_dict = None
    self.orig_scale = None
@property
def scale(self):
......@@ -472,6 +470,12 @@ class PassiveObserver(Observer):
def get_qparams(self):
return self.q_dict
def set_qparams(self, q_dict):
    """Adopt *q_dict* (deep-copied) and remember its original scale.

    Raises ``AssertionError`` when ``q_dict`` has no initialized
    ``"scale"`` entry.
    """
    self.q_dict = deepcopy(q_dict)
    scale = q_dict.get("scale")
    if scale is None:
        raise AssertionError("Can not get an initialized scale")
    self.orig_scale = scale.numpy()
def forward(self, x):
r"""
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
......
......@@ -152,7 +152,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
module = deepcopy(module)
def safe_call(func, q_dict):
    """Instantiate *func* (if any) and seed its qparams.

    :param func: zero-argument factory (class or partial) or ``None``.
    :param q_dict: quantization parameter dict, passed to the new
        instance's ``set_qparams`` when that method exists.
    :return: the new instance, or ``None`` when *func* is ``None``.
    """
    inst = func() if func is not None else None
    # Only push qparams into objects that actually expose set_qparams;
    # plain modules without it are returned untouched.
    if inst is not None and getattr(inst, "set_qparams", None) is not None:
        inst.set_qparams(q_dict)
    return inst
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
......
......@@ -41,8 +41,8 @@ def test_exponential_moving_average_observer():
m = ExponentialMovingAverageObserver(momentum=t)
m(mge.tensor(x1, dtype=np.float32))
m(mge.tensor(x2, dtype=np.float32))
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5)
np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5)
def test_histogram_observer():
......@@ -57,7 +57,8 @@ def test_histogram_observer():
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8")
m = PassiveObserver("qint8")
m.set_qparams(q_dict)
assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册