提交 7b2c5a73 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/quantization): fix histogram observer load and store issue

GitOrigin-RevId: b0a2b476e4490f960fc6fde48a715bbdab5ce128
上级 e6820b91
......@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import HistogramObserver, Observer
from .observer import HistogramObserver, Observer, ObserverMode
from .qconfig import (
QConfig,
calibration_qconfig,
......
......@@ -132,7 +132,7 @@ class MinMaxObserver(Observer):
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
)
# caculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / scale))
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
return q_dict
......@@ -204,7 +204,7 @@ class HistogramObserver(MinMaxObserver):
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = None
self.histogram = Buffer([0.0] * bins)
def _non_linear_param_search(self):
r"""Non-linear parameter search.
......@@ -212,6 +212,12 @@ class HistogramObserver(MinMaxObserver):
By selecting new min/max, we filter out outliers in input distribution.
"""
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
np_histogram = self.histogram.numpy()
assert len(np_histogram) == self.bins, "bins mistmatch"
bin_width = (np_max_val - np_min_val) / self.bins
def _get_norm(delta_begin, delta_end, density, norm_type):
r"""
Compute the norm of the values uniformaly distributed between
......@@ -233,9 +239,6 @@ class HistogramObserver(MinMaxObserver):
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
bin_width = (np_max_val - np_min_val) / self.bins
norm = 0.0
dst_bin_width = (
......@@ -262,7 +265,7 @@ class HistogramObserver(MinMaxObserver):
dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
)
density = self.histogram[src_bin] / bin_width
density = np_histogram[src_bin] / bin_width
if dst_bin_of_begin == dst_bin_of_end:
# if src_bin is entirely within 1 dst_bin
delta_begin = src_bin_begin - dst_bin_of_begin_center
......@@ -286,12 +289,9 @@ class HistogramObserver(MinMaxObserver):
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
assert len(self.histogram) == self.bins, "bins mistmatch"
bin_width = (self.max_val - self.min_val) / self.bins
# cumulative sum
total = sum(self.histogram)
cSum = np.cumsum(self.histogram, axis=0)
total = sum(np_histogram)
cSum = np.cumsum(np_histogram, axis=0)
stepsize = 1e-5 # granularity
alpha = 0.0 # lower bound
......@@ -400,46 +400,39 @@ class HistogramObserver(MinMaxObserver):
x = x_orig.numpy()
min_val = self.min_val.numpy()[0]
max_val = self.max_val.numpy()[0]
histogram = self.histogram.numpy()
new_min = x.min()
new_max = x.max()
if min_val == 0 or max_val == 0:
min_val = x.min()
max_val = x.max()
self.min_val.set_value(min_val)
self.max_val.set_value(max_val)
self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
self.histogram = self.histogram.astype(np.float64)
new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
else:
new_min = x.min()
new_max = x.max()
combined_min = min(new_min, min_val)
combined_max = max(new_max, max_val)
new_min = min(new_min, min_val)
new_max = max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(
combined_min,
combined_max,
downsample_rate,
start_idx,
) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
combined_histogram, _ = np.histogram(
x, self.bins, (combined_min, combined_max)
(new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
new_min, new_max, self.upsample_rate
)
combined_histogram = combined_histogram.astype(np.float64)
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
new_histogram = new_histogram.astype(np.float64)
if new_min == min_val and new_max == max_val:
new_histogram += histogram
else:
combined_histogram = self._combine_histograms(
combined_histogram,
self.histogram,
new_histogram = self._combine_histograms(
new_histogram,
histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins,
)
self.histogram = combined_histogram
self.min_val.set_value(combined_min)
self.max_val.set_value(combined_max)
self.histogram.set_value(new_histogram)
self.min_val.set_value(new_min)
self.max_val.set_value(new_max)
def forward(self, x_orig):
self.sideeffect_forward(x_orig)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册