提交 205291a3 编写于 作者: M Megvii Engine Team

feat(mge/quantization): add histgram observer

GitOrigin-RevId: a9252a6bafe19b3ac958acb7a617bbbb47dc1514
上级 7c4f1a38
......@@ -486,8 +486,16 @@ class QATModule(Module):
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
self.weight_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.weight_observer.dtype)
)
self.act_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.act_observer.dtype)
)
def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)
......@@ -496,11 +504,10 @@ class QATModule(Module):
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
if self.quantizing == self.QATMode.CALIBRATION:
return oup
else:
if fq is not None:
scale, zero_point = obs.get_qparams()
return fq(oup, scale, zero_point)
oup = fq(oup, scale, zero_point)
return oup
def set_qat_mode(self, mode: QATMode):
r"""
......
......@@ -6,8 +6,13 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .observer import HistogramObserver, Observer
from .qconfig import (
QConfig,
calibration_qconfig,
ema_fakequant_qconfig,
min_max_fakequant_qconfig,
)
from .quantize import (
disable_fake_quant,
disable_observer,
......
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
import numpy as np
......@@ -12,6 +13,7 @@ import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor
from ..jit import sideeffect
from ..module import Module
......@@ -94,9 +96,11 @@ class MinMaxObserver(Observer):
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)
def get_qparams(self):
def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
if self.symmetric:
symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
......@@ -105,14 +109,16 @@ class MinMaxObserver(Observer):
else:
# use maximun to avoid scale too small at the begin
scale = F.maximum(
(self.max_val - self.min_val) / (self.qmax - self.qmin),
self.scale_limit,
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
)
# caculate zero_point
zero_point = self.qmin - Round()((self.min_val / scale))
zero_point = self.qmin - Round()((min_val / scale))
return scale, zero_point
def get_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
def forward(self, x_orig):
if self.enabled:
# stop gradient
......@@ -161,3 +167,251 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
)
self.set_min_max(tmp_min, tmp_max)
return x_orig
class HistogramObserver(MinMaxObserver):
def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
super().__init__(dtype=dtype, *args, **kwargs)
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = None
def _non_linear_param_search(self):
r"""Non-linear parameter search.
An approximation for L2 error minimization for selecting min/max.
By selecting new min/max, we filter out outliers in input distribution.
"""
def _get_norm(delta_begin, delta_end, density, norm_type):
r"""
Compute the norm of the values uniformaly distributed between
delta_begin and delta_end.
norm = density * (integral_{begin, end} x^2)
= density * (end^3 - begin^3) / 3
"""
assert norm_type == "L2", "Only L2 norms are currently supported"
norm = 0.0
if norm_type == "L2":
norm = (
delta_end * delta_end * delta_end
- delta_begin * delta_begin * delta_begin
) / 3
return density * norm
def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
r"""
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
bin_width = (np_max_val - np_min_val) / self.bins
norm = 0.0
dst_bin_width = (
bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
)
if dst_bin_width == 0.0:
return 0.0
for src_bin in range(self.bins):
# distances from the beginning of first dst_bin to the beginning and
# end of src_bin
src_bin_begin = (src_bin - next_start_bin) * bin_width
src_bin_end = src_bin_begin + bin_width
# which dst_bins the beginning and end of src_bin belong to?
dst_bin_of_begin = min(
self.dst_nbins - 1,
max(0.0, math.floor(src_bin_begin / dst_bin_width)),
)
dst_bin_of_end = min(
self.dst_nbins - 1,
max(0.0, math.floor(src_bin_end / dst_bin_width)),
)
dst_bin_of_begin_center = (
dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
)
density = self.histogram[src_bin] / bin_width
if dst_bin_of_begin == dst_bin_of_end:
# if src_bin is entirely within 1 dst_bin
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = src_bin_end - dst_bin_of_begin_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
else:
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = dst_bin_width / 2
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-dst_bin_width / 2, dst_bin_width / 2, density, norm_type
)
dst_bin_of_end_center = (
dst_bin_of_end * dst_bin_width + dst_bin_width / 2
)
delta_begin = -dst_bin_width / 2
delta_end = src_bin_end - dst_bin_of_end_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
assert len(self.histogram) == self.bins, "bins mistmatch"
bin_width = (self.max_val - self.min_val) / self.bins
# cumulative sum
total = sum(self.histogram)
cSum = np.cumsum(self.histogram, axis=0)
stepsize = 1e-5 # granularity
alpha = 0.0 # lower bound
beta = 1.0 # upper bound
start_bin = 0
end_bin = self.bins - 1
norm_min = float("inf")
while alpha < beta:
# Find the next step
next_alpha = alpha + stepsize
next_beta = beta - stepsize
# find the left and right bins between the quantile bounds
l = start_bin
r = end_bin
while l < end_bin and cSum[l] < next_alpha * total:
l = l + 1
while r > start_bin and cSum[r] > next_beta * total:
r = r - 1
# decide the next move
next_start_bin = start_bin
next_end_bin = end_bin
if (l - start_bin) > (end_bin - r):
# move the start bin
next_start_bin = l
alpha = next_alpha
else:
# move the end bin
next_end_bin = r
beta = next_beta
if next_start_bin == start_bin and next_end_bin == end_bin:
continue
# calculate the quantization error using next_start_bin and next_end_bin
norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
if norm > norm_min:
break
norm_min = norm
start_bin = next_start_bin
end_bin = next_end_bin
new_min = self.min_val + bin_width * start_bin
new_max = self.min_val + bin_width * (end_bin + 1)
return new_min, new_max
def get_qparams(self):
new_min, new_max = self._non_linear_param_search()
return self._calculate_qparams(new_min, new_max)
def _combine_histograms(
self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
):
# First up-sample the histogram with new data by a factor of L
# This creates an approximate probability density thats piecwise constant
upsampled_histogram = new_hist.repeat(upsample_rate)
# Now insert the upsampled histogram into the output
# histogram, which is initialized with zeros.
# The offset at which the histogram is introduced is determined
# by the start index as the output histogram can cover a wider range
histogram_with_output_range = np.zeros((Nbins * downsample_rate))
histogram_with_output_range[
start_idx : Nbins * upsample_rate + start_idx
] = upsampled_histogram
# Compute integral histogram, double precision is needed to ensure
# that there are no overflows
integral_histogram = np.cumsum(histogram_with_output_range, 0)[
downsample_rate - 1 :: downsample_rate
]
# Finally perform interpolation
shifted_integral_histogram = np.zeros((Nbins))
shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
interpolated_histogram = (
integral_histogram - shifted_integral_histogram
) / upsample_rate
orig_hist = orig_hist + interpolated_histogram
return orig_hist
def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
# We ensure that:
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
# This allows us to have a common grid of resolution s, where we can align
# the input histogram
# start_idx maps min_val to the histogram bin index.
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
downsample_rate = int(
np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
)
e = downsample_rate * (self.bins * hist_bin_width) - (
combined_max - combined_min
)
combined_max = combined_max + e / 2
combined_min = combined_min - e / 2
start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
return combined_min, combined_max, downsample_rate, start_idx
@sideeffect
def sideeffect_forward(self, x_orig):
x = x_orig.numpy()
min_val = self.min_val.numpy()[0]
max_val = self.max_val.numpy()[0]
if min_val == 0 or max_val == 0:
min_val = x.min()
max_val = x.max()
self.min_val.set_value(min_val)
self.max_val.set_value(max_val)
self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
self.histogram = self.histogram.astype(np.float64)
else:
new_min = x.min()
new_max = x.max()
combined_min = min(new_min, min_val)
combined_max = max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(
combined_min,
combined_max,
downsample_rate,
start_idx,
) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
combined_histogram, _ = np.histogram(
x, self.bins, (combined_min, combined_max)
)
combined_histogram = combined_histogram.astype(np.float64)
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
else:
combined_histogram = self._combine_histograms(
combined_histogram,
self.histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins,
)
self.histogram = combined_histogram
self.min_val.set_value(combined_min)
self.max_val.set_value(combined_max)
def forward(self, x_orig):
self.sideeffect_forward(x_orig)
return x_orig
......@@ -5,11 +5,13 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from ..module import Module
from .fake_quant import FakeQuantize
from .observer import ExponentialMovingAverageObserver, MinMaxObserver
from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
)
class QConfig:
......@@ -66,3 +68,7 @@ ema_fakequant_qconfig = QConfig(
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
)
calibration_qconfig = QConfig(
weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None,
)
......@@ -71,7 +71,6 @@ def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfi
mod.set_qconfig(qconfig)
module.apply(fn)
enable_observer(module)
def disable_fake_quant(module: Module):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册