Commit fda9599a authored by Megvii Engine Team, committed by Xu Xinran

feat(mge/quant): add TQT quant method

GitOrigin-RevId: 00b1616e73ed34c8c09e2407b8fc7d90230f8cec
Parent 285d70cb
@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result
...
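The added setattr call restores saved_tensors on the deepcopied object instead of leaving it as None, which is also why the old assertion in test_deepcopy is removed further down. A minimal stand-alone sketch of the same __deepcopy__ pattern, using a made-up Holder class rather than MegEngine's Function:

import copy

class Holder:
    """Minimal stand-in (not MegEngine's Function) showing the __deepcopy__ pattern above."""

    def __init__(self):
        self.saved_tensors = ("a", "b")
        self.param = 1

    def __deepcopy__(self, memo):
        tmp = self.saved_tensors
        self.saved_tensors = None              # detach the field so it is not deep-copied
        result = Holder.__new__(Holder)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)  # the added line: restore it on the copy...
        self.saved_tensors = tmp               # ...and on the original
        return result

h = Holder()
h2 = copy.deepcopy(h)
print(h2.saved_tensors)  # ('a', 'b') -- without the setattr this would be None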
@@ -77,13 +77,19 @@ class QATModule(Module):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        return self.weight_observer.get_dtype()
+        if hasattr(self.weight_fake_quant, "get_dtype"):
+            return self.weight_fake_quant.get_dtype()
+        else:
+            return self.weight_observer.get_dtype()

     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        return self.act_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.act_fake_quant.get_dtype()
+        else:
+            return self.act_observer.get_dtype()

     @classmethod
     @abstractmethod
...
@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )
@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+
+import numpy as np
+
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round


-class FakeQuantize(Module):
-    r"""
-    A module to do quant and dequant according to observer's scale and zero_point.
-    """
-
+class _FakeQuantize(Module):
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False

+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_foward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
-        return inp
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_foward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |inp_scaled| > L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |inp_scaled| <= L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |inp_scaled| <= L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |inp_scaled| > L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enabled, TQT does the fake-quant forward and finetunes the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_foward(self, inp, q_dict):
+        # when disabled, TQT does the normal forward and initializes the scale
+        # from the observer's statistics
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
+        return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup) * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
...
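For reference, the transform implemented by TQT_Function.forward can be written out in plain NumPy: the threshold t = 2**s is kept in the log2 domain, the input is scaled by 1/t, clipped to [qmin, qmax], rounded, and scaled back. This is only an illustrative sketch of the arithmetic, not part of the diff; the sample values are made up.

import numpy as np

def tqt_fake_quant(x, s, qmin=-127, qmax=127):
    """Fake-quantize x with a threshold stored as log2(t) = s (cf. TQT_Function.forward)."""
    t = 2.0 ** s                               # learned threshold, kept in log2 domain
    x_scaled = x / t                           # scale into the integer range
    x_clipped = np.clip(x_scaled, qmin, qmax)  # saturate to [qmin, qmax]
    x_rounded = np.round(x_clipped)            # round to the nearest integer level
    return x_rounded * t                       # dequantize back to the float domain

x = np.array([-3.0, -0.4, 0.05, 0.7, 2.5], dtype=np.float32)
s = np.float32(-5.0)                           # threshold t = 2**-5 = 0.03125
print(tqt_fake_quant(x, s))                    # values snap to multiples of 0.03125,
                                               # saturating at +/- 127 * 0.03125 ~= 3.97

Learning s rather than t directly is also why the backward rule above carries the extra t * ln(2) factor: it is the chain rule through t = 2**s.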
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximum to avoid the scale being too small at the beginning
...
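The min_val/max_val entries added to q_dict are what TQT.normal_foward consumes: before fake-quant training starts, the log2-threshold is seeded from the observed dynamic range as s0 = log2(max(|min_val|, |max_val|) / 127). A quick numeric check of that formula, with made-up observer statistics:

import numpy as np

# Initial scale used by TQT.normal_foward: s0 = log2(max(|min_val|, |max_val|) / 127)
min_val, max_val = -1.5, 2.3          # example statistics an observer might report
t0 = max(abs(min_val), abs(max_val)) / 127
s0 = np.log2(t0)
print(t0, s0)                         # ~0.0181, ~-5.79: the threshold starts at observed max / 127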
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant


+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
...
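With the new export, the TQT qconfig is selected like any other QConfig. A rough usage sketch follows; quantize_qat is assumed to be the float-to-QAT conversion helper available in this version of megengine.quantization.quantize, and Net is only a placeholder model:

import megengine.module as M
from megengine.quantization import tqt_quant_qconfig
from megengine.quantization.quantize import quantize_qat  # assumed conversion helper

class Net(M.Module):  # placeholder float model
    def __init__(self):
        super().__init__()
        self.conv = M.ConvBn2d(3, 16, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)

net = Net()
# Swap float modules for their QAT counterparts; observers collect statistics,
# then TQT finetunes the per-tensor scale (stored as log2 of the threshold) during training.
qat_net = quantize_qat(net, qconfig=tqt_quant_qconfig)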
@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None


 def test_save_context():
...
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.test import assertTensorClose


class numpy_TQT_Function:
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = np.maximum(
            np.minimum(inp_scaled, self.upperbound), self.lowerbound
        )
        inp_rounded = np.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.saved_tensors = (inp_scaled, inp_rounded, t)
        return inp_flq

    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |inp_scaled| > L
        mask_quant = np.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |inp_scaled| <= L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |inp_scaled| <= L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |inp_scaled| > L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * np.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s


def test_TQT():
    f = TQT_Function(-127, 127)
    nf = numpy_TQT_Function(-127, 127)

    def check_inp(a, b, c, a_np, b_np, c_np):
        assertTensorClose(
            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
        )
        c1, c2 = f.backward(c)
        c1_np, c2_np = nf.backward(c_np)
        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
        assertTensorClose(c2.numpy(), c2_np.astype("float32"))

    a = tensor()
    b = tensor()
    a_np = np.random.random((4, 3)).astype("float32")
    b_np = np.random.random((1)).astype("float32")
    a.set_value(a_np)
    b.set_value(b_np)
    check_inp(a, b, b, a_np, b_np, b_np)