提交 fda9599a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mge/quant): add TQT quant method

GitOrigin-RevId: 00b1616e73ed34c8c09e2407b8fc7d90230f8cec
上级 285d70cb
......@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
setattr(result, "saved_tensors", tmp)
self.saved_tensors = tmp
return result
......
......@@ -77,12 +77,18 @@ class QATModule(Module):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.act_fake_quant, "get_dtype"):
return self.weight_fake_quant.get_dtype()
else:
return self.weight_observer.get_dtype()
def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.act_fake_quant, "get_dtype"):
return self.act_fake_quant.get_dtype()
else:
return self.act_observer.get_dtype()
@classmethod
......
......@@ -12,4 +12,5 @@ from .qconfig import (
calibration_qconfig,
ema_fakequant_qconfig,
min_max_fakequant_qconfig,
tqt_quant_qconfig,
)
......@@ -5,17 +5,20 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import ObserverMode, Round
class FakeQuantize(Module):
r"""
A module to do quant and dequant according to observer's scale and zero_point.
"""
class _FakeQuantize(Module):
def __init__(self, dtype: str, enable: bool = True):
super().__init__()
if not dtype in _metadata_dict.keys():
......@@ -35,8 +38,87 @@ class FakeQuantize(Module):
def disable(self):
self.enabled = False
def fake_quant_forward(self, inp, q_dict):
return inp
def normal_foward(self, inp, q_dict):
return inp
def forward(self, inp, q_dict):
if self.enabled:
return self.fake_quant_forward(inp, q_dict)
else:
return self.normal_foward(inp, q_dict)
class TQT_Function(Function):
def __init__(self, lowerbound, upperbound):
super().__init__()
self.lowerbound = lowerbound
self.upperbound = upperbound
def forward(self, inp, scale):
t = 2 ** scale
# t = F.maximum(t, 1e-4)
inp_scaled = inp / t
inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
inp_rounded = F.round(inp_clipped)
inp_flq = inp_rounded * t
self.save_for_backward(inp_scaled, inp_rounded, t)
return inp_flq
def backward(self, grad_inp_flq):
(inp_scaled, inp_rounded, t) = self.saved_tensors
mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
inp_scaled > self.upperbound + 0.5
) # mask for accumulating the gradients of |data_scaled|>L
mask_quant = F.abs(
mask_clip - 1
) # mask for accumulating the gradients with |data_scaled|<=L
grad_quant = (
grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
) # gradient within |data_scaled|<=L
grad_clip = (
grad_inp_flq * mask_clip * inp_rounded
) # gradient with | data_scaled|>L
grad_s = grad_clip.sum() + grad_quant.sum()
# dL/ds = dL/dt * t * ln(2)
grad_s = grad_s * t * math.log(2)
grad_inp = grad_inp_flq * mask_quant
return grad_inp, grad_s
class TQT(_FakeQuantize):
"""
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
"""
def __init__(self, dtype: str, enable: bool = True):
super().__init__(dtype, enable)
self.scale = Parameter(0.0, dtype=np.float32)
def fake_quant_forward(self, inp, q_dict):
# when enable, TQT will do fakequant forward, finetune the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict):
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp
def get_dtype(self):
return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
class FakeQuantize(_FakeQuantize):
r"""
A module to do quant and dequant according to observer's scale and zero_point.
"""
def fake_quant_forward(self, inp, q_dict):
if q_dict["mode"] == ObserverMode.SYMMERTIC:
scale = q_dict["scale"]
# Quant
......@@ -56,4 +138,3 @@ class FakeQuantize(Module):
# DeQuant
oup = (oup - zero_point) * scale
return oup
return inp
......@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
q_dict = create_observer_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
if self.mode == ObserverMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
#'
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..module import Module
from .fake_quant import FakeQuantize
from .fake_quant import TQT, FakeQuantize
from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
......@@ -52,6 +52,12 @@ class QConfig:
self.fake_quant = fake_quant
tqt_quant_qconfig = QConfig(
weight_observer=ExponentialMovingAverageObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=TQT,
)
# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
weight_observer=MinMaxObserver,
......
......@@ -96,7 +96,6 @@ def test_deepcopy():
origin = Sigmoid(0)
new = copy.deepcopy(Sigmoid(0))
assert new.param == origin.param
assert new.saved_tensors == None
def test_save_context():
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.test import assertTensorClose
class numpy_TQT_Function:
def __init__(self, lowerbound, upperbound):
super().__init__()
self.lowerbound = lowerbound
self.upperbound = upperbound
def forward(self, inp, scale):
t = 2 ** scale
# t = F.maximum(t, 1e-4)
inp_scaled = inp / t
inp_clipped = np.maximum(
np.minimum(inp_scaled, self.upperbound), self.lowerbound
)
inp_rounded = np.round(inp_clipped)
inp_flq = inp_rounded * t
self.saved_tensors = (inp_scaled, inp_rounded, t)
return inp_flq
def backward(self, grad_inp_flq):
(inp_scaled, inp_rounded, t) = self.saved_tensors
mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
inp_scaled > self.upperbound + 0.5
) # mask for accumulating the gradients of |data_scaled|>L
mask_quant = np.abs(
mask_clip - 1
) # mask for accumulating the gradients with |data_scaled|<=L
grad_quant = (
grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
) # gradient within |data_scaled|<=L
grad_clip = (
grad_inp_flq * mask_clip * inp_rounded
) # gradient with | data_scaled|>L
grad_s = grad_clip.sum() + grad_quant.sum()
# dL/ds = dL/dt * t * ln(2)
grad_s = grad_s * t * np.log(2)
grad_inp = grad_inp_flq * mask_quant
return grad_inp, grad_s
def test_TQT():
f = TQT_Function(-127, 127)
nf = numpy_TQT_Function(-127, 127)
def check_inp(a, b, c, a_np, b_np, c_np):
assertTensorClose(
f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
)
c1, c2 = f.backward(c)
c1_np, c2_np = nf.backward(c_np)
assertTensorClose(c1.numpy(), c1_np.astype("float32"))
assertTensorClose(c2.numpy(), c2_np.astype("float32"))
a = tensor()
b = tensor()
a_np = np.random.random((4, 3)).astype("float32")
b_np = np.random.random((1)).astype("float32")
a.set_value(a_np)
b.set_value(b_np)
check_inp(a, b, b, a_np, b_np, b_np)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册