diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py
index 60e77dca0a7f53476e6189ed8315fd3163cb9196..1ee440aba18c3b682042630764a8dacbb4811223 100644
--- a/python_module/megengine/module/module.py
+++ b/python_module/megengine/module/module.py
@@ -486,8 +486,16 @@ class QATModule(Module):
         self.weight_observer = qconfig.weight_observer()
         self.act_observer = qconfig.act_observer()

-        self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
-        self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
+        self.weight_fake_quant = (
+            None
+            if qconfig.fake_quant is None
+            else qconfig.fake_quant(self.weight_observer.dtype)
+        )
+        self.act_fake_quant = (
+            None
+            if qconfig.fake_quant is None
+            else qconfig.fake_quant(self.act_observer.dtype)
+        )

     def apply_observer(self, target: Tensor, obs: "Observer"):
         return obs(target)
@@ -496,11 +504,10 @@
         self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
     ):
         oup = self.apply_observer(target, obs)
-        if self.quantizing == self.QATMode.CALIBRATION:
-            return oup
-        else:
+        if fq is not None:
             scale, zero_point = obs.get_qparams()
-            return fq(oup, scale, zero_point)
+            oup = fq(oup, scale, zero_point)
+        return oup

     def set_qat_mode(self, mode: QATMode):
         r"""
diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py
index 9d490be8b0ea019ddc07ad056d29fc757efb7d5a..1e99493fff001d599ee3041d84a408621f0cf165 100644
--- a/python_module/megengine/quantization/__init__.py
+++ b/python_module/megengine/quantization/__init__.py
@@ -6,8 +6,13 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .fake_quant import FakeQuantize
-from .observer import Observer
-from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
+from .observer import HistogramObserver, Observer
+from .qconfig import (
+    QConfig,
+    calibration_qconfig,
+    ema_fakequant_qconfig,
+    min_max_fakequant_qconfig,
+)
 from .quantize import (
     disable_fake_quant,
     disable_observer,
diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py
index b6799e790744cb8a8568cf3a3ec435fba016d42f..bd85a9a9788e5d05f0e1e0645385878b42565017 100644
--- a/python_module/megengine/quantization/observer.py
+++ b/python_module/megengine/quantization/observer.py
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import math
 from abc import abstractmethod

 import numpy as np
@@ -12,6 +13,7 @@
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, tensor
+from ..jit import sideeffect
 from ..module import Module


@@ -94,9 +96,11 @@ class MinMaxObserver(Observer):
         F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
         F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)

-    def get_qparams(self):
+    def _calculate_qparams(self, inp_min_val, inp_max_val):
+        min_val = F.minimum(0.0, inp_min_val)
+        max_val = F.maximum(0.0, inp_max_val)
         if self.symmetric:
-            symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
+            symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
             scale = F.maximum(
                 symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
@@ -105,14 +109,16 @@ class MinMaxObserver(Observer):
         else:
             # use maximun to avoid scale too small at the begin
             scale = F.maximum(
-                (self.max_val - self.min_val) / (self.qmax - self.qmin),
-                self.scale_limit,
+                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
             )
             # caculate zero_point
-            zero_point = self.qmin - Round()((self.min_val / scale))
+            zero_point = self.qmin - Round()((min_val / scale))

         return scale, zero_point

+    def get_qparams(self):
+        return self._calculate_qparams(self.min_val, self.max_val)
+
     def forward(self, x_orig):
         if self.enabled:
             # stop gradient
@@ -161,3 +167,251 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
             )
             self.set_min_max(tmp_min, tmp_max)
         return x_orig
+
+
+class HistogramObserver(MinMaxObserver):
+    def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
+        super().__init__(dtype=dtype, *args, **kwargs)
+        self.bins = bins
+        self.upsample_rate = upsample_rate
+        self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
+        self.histogram = None
+
+    def _non_linear_param_search(self):
+        r"""Non-linear parameter search.
+        An approximation for L2 error minimization for selecting min/max.
+        By selecting new min/max, we filter out outliers in input distribution.
+        """
+
+        def _get_norm(delta_begin, delta_end, density, norm_type):
+            r"""
+            Compute the norm of the values uniformaly distributed between
+            delta_begin and delta_end.
+            norm = density * (integral_{begin, end} x^2)
+                 = density * (end^3 - begin^3) / 3
+            """
+            assert norm_type == "L2", "Only L2 norms are currently supported"
+            norm = 0.0
+            if norm_type == "L2":
+                norm = (
+                    delta_end * delta_end * delta_end
+                    - delta_begin * delta_begin * delta_begin
+                ) / 3
+            return density * norm
+
+        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
+            r"""
+            Compute the quantization error if we use start_bin to end_bin as the
+            min and max to do the quantization.
+            """
+            np_min_val = self.min_val.numpy()[0]
+            np_max_val = self.max_val.numpy()[0]
+            bin_width = (np_max_val - np_min_val) / self.bins
+
+            norm = 0.0
+            dst_bin_width = (
+                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
+            )
+            if dst_bin_width == 0.0:
+                return 0.0
+            for src_bin in range(self.bins):
+                # distances from the beginning of first dst_bin to the beginning and
+                # end of src_bin
+                src_bin_begin = (src_bin - next_start_bin) * bin_width
+                src_bin_end = src_bin_begin + bin_width
+
+                # which dst_bins the beginning and end of src_bin belong to?
+                dst_bin_of_begin = min(
+                    self.dst_nbins - 1,
+                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
+                )
+                dst_bin_of_end = min(
+                    self.dst_nbins - 1,
+                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
+                )
+                dst_bin_of_begin_center = (
+                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
+                )
+
+                density = self.histogram[src_bin] / bin_width
+                if dst_bin_of_begin == dst_bin_of_end:
+                    # if src_bin is entirely within 1 dst_bin
+                    delta_begin = src_bin_begin - dst_bin_of_begin_center
+                    delta_end = src_bin_end - dst_bin_of_begin_center
+                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
+                else:
+                    delta_begin = src_bin_begin - dst_bin_of_begin_center
+                    delta_end = dst_bin_width / 2
+                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
+
+                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
+                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
+                    )
+
+                    dst_bin_of_end_center = (
+                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
+                    )
+
+                    delta_begin = -dst_bin_width / 2
+                    delta_end = src_bin_end - dst_bin_of_end_center
+                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
+            return norm
+
+        assert len(self.histogram) == self.bins, "bins mistmatch"
+        bin_width = (self.max_val - self.min_val) / self.bins
+
+        # cumulative sum
+        total = sum(self.histogram)
+        cSum = np.cumsum(self.histogram, axis=0)
+
+        stepsize = 1e-5  # granularity
+        alpha = 0.0  # lower bound
+        beta = 1.0  # upper bound
+        start_bin = 0
+        end_bin = self.bins - 1
+        norm_min = float("inf")
+
+        while alpha < beta:
+            # Find the next step
+            next_alpha = alpha + stepsize
+            next_beta = beta - stepsize
+
+            # find the left and right bins between the quantile bounds
+            l = start_bin
+            r = end_bin
+            while l < end_bin and cSum[l] < next_alpha * total:
+                l = l + 1
+            while r > start_bin and cSum[r] > next_beta * total:
+                r = r - 1
+
+            # decide the next move
+            next_start_bin = start_bin
+            next_end_bin = end_bin
+            if (l - start_bin) > (end_bin - r):
+                # move the start bin
+                next_start_bin = l
+                alpha = next_alpha
+            else:
+                # move the end bin
+                next_end_bin = r
+                beta = next_beta
+
+            if next_start_bin == start_bin and next_end_bin == end_bin:
+                continue
+
+            # calculate the quantization error using next_start_bin and next_end_bin
+            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
+
+            if norm > norm_min:
+                break
+            norm_min = norm
+            start_bin = next_start_bin
+            end_bin = next_end_bin
+
+        new_min = self.min_val + bin_width * start_bin
+        new_max = self.min_val + bin_width * (end_bin + 1)
+        return new_min, new_max
+
+    def get_qparams(self):
+        new_min, new_max = self._non_linear_param_search()
+        return self._calculate_qparams(new_min, new_max)
+
+    def _combine_histograms(
+        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
+    ):
+        # First up-sample the histogram with new data by a factor of L
+        # This creates an approximate probability density thats piecwise constant
+        upsampled_histogram = new_hist.repeat(upsample_rate)
+        # Now insert the upsampled histogram into the output
+        # histogram, which is initialized with zeros.
+        # The offset at which the histogram is introduced is determined
+        # by the start index as the output histogram can cover a wider range
+        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
+        histogram_with_output_range[
+            start_idx : Nbins * upsample_rate + start_idx
+        ] = upsampled_histogram
+        # Compute integral histogram, double precision is needed to ensure
+        # that there are no overflows
+        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
+            downsample_rate - 1 :: downsample_rate
+        ]
+        # Finally perform interpolation
+        shifted_integral_histogram = np.zeros((Nbins))
+        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
+        interpolated_histogram = (
+            integral_histogram - shifted_integral_histogram
+        ) / upsample_rate
+        orig_hist = orig_hist + interpolated_histogram
+        return orig_hist
+
+    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
+        # We ensure that:
+        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
+        # This allows us to have a common grid of resolution s, where we can align
+        # the input histogram
+        # start_idx maps min_val to the histogram bin index.
+        np_min_val = self.min_val.numpy()[0]
+        np_max_val = self.max_val.numpy()[0]
+
+        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
+        downsample_rate = int(
+            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
+        )
+        e = downsample_rate * (self.bins * hist_bin_width) - (
+            combined_max - combined_min
+        )
+        combined_max = combined_max + e / 2
+        combined_min = combined_min - e / 2
+        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
+
+        return combined_min, combined_max, downsample_rate, start_idx
+
+    @sideeffect
+    def sideeffect_forward(self, x_orig):
+        x = x_orig.numpy()
+        min_val = self.min_val.numpy()[0]
+        max_val = self.max_val.numpy()[0]
+        if min_val == 0 or max_val == 0:
+            min_val = x.min()
+            max_val = x.max()
+            self.min_val.set_value(min_val)
+            self.max_val.set_value(max_val)
+            self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
+            self.histogram = self.histogram.astype(np.float64)
+        else:
+            new_min = x.min()
+            new_max = x.max()
+            combined_min = min(new_min, min_val)
+            combined_max = max(new_max, max_val)
+            # combine the existing histogram and new histogram into 1 histogram
+            # We do this by first upsampling the histogram to a dense grid
+            # and then downsampling the histogram efficiently
+            (
+                combined_min,
+                combined_max,
+                downsample_rate,
+                start_idx,
+            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
+
+            combined_histogram, _ = np.histogram(
+                x, self.bins, (combined_min, combined_max)
+            )
+            combined_histogram = combined_histogram.astype(np.float64)
+            if combined_min == min_val and combined_max == max_val:
+                combined_histogram += self.histogram
+            else:
+                combined_histogram = self._combine_histograms(
+                    combined_histogram,
+                    self.histogram,
+                    self.upsample_rate,
+                    downsample_rate,
+                    start_idx,
+                    self.bins,
+                )
+            self.histogram = combined_histogram
+            self.min_val.set_value(combined_min)
+            self.max_val.set_value(combined_max)
+
+    def forward(self, x_orig):
+        self.sideeffect_forward(x_orig)
+        return x_orig
diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py
index 9a52430088c5773bc990181ce261aa4227c5bd6c..cfabdc58eebadfbb21f3f02de2d57503f8e1992c 100644
--- a/python_module/megengine/quantization/qconfig.py
+++ b/python_module/megengine/quantization/qconfig.py
@@ -5,11 +5,13 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from functools import partial
-
 from ..module import Module
 from .fake_quant import FakeQuantize
-from .observer import ExponentialMovingAverageObserver, MinMaxObserver
+from .observer import (
+    ExponentialMovingAverageObserver,
+    HistogramObserver,
+    MinMaxObserver,
+)


 class QConfig:
@@ -66,3 +68,7 @@
     act_observer=ExponentialMovingAverageObserver,
     fake_quant=FakeQuantize,
 )
+
+calibration_qconfig = QConfig(
+    weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None,
+)
diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py
index 1bfba352a322f41808d67db6d95807db51abfba6..3296f5762dac58936f7751b8ee3dbfd19fc0bde2 100644
--- a/python_module/megengine/quantization/quantize.py
+++ b/python_module/megengine/quantization/quantize.py
@@ -71,7 +71,6 @@ def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfi
         mod.set_qconfig(qconfig)

     module.apply(fn)
-    enable_observer(module)


 def disable_fake_quant(module: Module):
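
Usage note (not part of the patch): a minimal sketch of the post-training calibration flow these changes enable, assuming a QAT-converted network. `model`, `calib_loader`, and the `model.some_qat_layer` attribute path are hypothetical placeholders; the names used (`quantize_calibration`, `enable_observer`, `calibration_qconfig`, `act_observer`, `get_qparams`) come from the files touched above, but their exact behaviour on this branch may differ.

    # Hedged calibration sketch; `model` and `calib_loader` are hypothetical.
    from megengine.quantization import calibration_qconfig
    from megengine.quantization.quantize import enable_observer, quantize_calibration

    # Attach MinMaxObserver to weights and HistogramObserver to activations;
    # with fake_quant=None, apply_fake_quant_with_observer only observes.
    quantize_calibration(model, qconfig=calibration_qconfig)
    enable_observer(model)  # no longer enabled implicitly (see the quantize.py hunk)

    for batch in calib_loader:   # batches assumed to already be megengine Tensors
        model(batch)             # each forward accumulates the activation histogram

    # After calibration, qparams come from the non-linear min/max search.
    scale, zero_point = model.some_qat_layer.act_observer.get_qparams()

Because `fake_quant` may now be `None`, `apply_fake_quant_with_observer` falls back to returning the observed tensor unchanged, which is what makes a pure-observation calibration pass possible.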