diff --git a/example/auto_compression/image_classification/configs/eval.yaml b/example/auto_compression/image_classification/configs/eval.yaml
index b11e4c048f283b087a7225621f4269958802c971..f7e61515c63439f0233268d05aaa968a43938192 100644
--- a/example/auto_compression/image_classification/configs/eval.yaml
+++ b/example/auto_compression/image_classification/configs/eval.yaml
@@ -1,8 +1,8 @@
 Global:
-  model_dir: './MobileNetV1_infer'
+  model_dir: './mobilenet_dbb_inference'
   model_filename: 'inference.pdmodel'
   params_filename: "inference.pdiparams"
   batch_size: 128
-  data_dir: './ILSVRC2012_data_demo/ILSVRC2012/'
+  data_dir: './ILSVRC2012/'
   img_size: 224
   resize_size: 256
diff --git a/example/auto_compression/image_classification/eval.py b/example/auto_compression/image_classification/eval.py
index 790f354354bd82dccc9b3e6e33a679b05d32305d..ecf6e04aae29c1ffe5a111eeee103d93826bde89 100644
--- a/example/auto_compression/image_classification/eval.py
+++ b/example/auto_compression/image_classification/eval.py
@@ -31,12 +31,12 @@ def argsparser():
     parser.add_argument(
         '--config_path',
         type=str,
-        default='./image_classification/configs/eval.yaml',
+        default='./configs/eval.yaml',
         help="path of compression strategy config.")
     parser.add_argument(
         '--model_dir',
         type=str,
-        default='./MobileNetV1_infer',
+        default='./mobilenet_dbb_inference',
         help='model directory')
     return parser

@@ -65,6 +65,15 @@ def eval():
         exe,
         model_filename=global_config["model_filename"],
         params_filename=global_config["params_filename"])
+
+    features = None
+    for _var in val_program.list_vars():
+        print(f"meeting: {_var.name}")
+        if _var.name == "conv2d_98.tmp_1":
+            print(f"find {_var.name}")
+            features = _var
+
+    fetch_targets.append(features)
     print('Loaded model from: {}'.format(global_config["model_dir"]))

     val_loader = eval_reader(
@@ -77,9 +86,13 @@ def eval():
     for batch_id, (image, label) in enumerate(val_loader):
         image = np.array(image)
         label = np.array(label).astype('int64')
-        pred = exe.run(val_program,
-                       feed={feed_target_names[0]: image},
-                       fetch_list=fetch_targets)
+        pred = exe.run(
+            val_program,
+            feed={feed_target_names[0]: image},
+            fetch_list=fetch_targets)
+        features = np.array(pred[1])
+        print(f"feature shape: {features.shape}")
+        pred = np.array(pred[0])
         label = np.array(label)
         sort_array = pred.argsort(axis=1)

@@ -92,6 +105,7 @@ def eval():
                 acc_num += 1
         top_5 = float(acc_num) / len(label)
         results.append([top_1, top_5])
+        break
     result = np.mean(np.array(results), axis=0)
     return result[0]

@@ -107,10 +121,10 @@ def main(args):
     global_config['model_dir'] = args.model_dir

     global img_size, resize_size
-    img_size = int(global_config[
-        'img_size']) if 'img_size' in global_config else 224
-    resize_size = int(global_config[
-        'resize_size']) if 'resize_size' in global_config else 256
+    img_size = int(
+        global_config['img_size']) if 'img_size' in global_config else 224
+    resize_size = int(
+        global_config['resize_size']) if 'resize_size' in global_config else 256

     result = eval()
     print('Eval Top1:', result)
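Note: fetching `conv2d_98.tmp_1` above is a debugging aid specific to the converted DBB-MobileNet model; the variable name will differ for other programs. A minimal standalone sketch of the same pattern, with hypothetical paths and the Paddle 2.x static-graph API:

```python
import numpy as np
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Hypothetical model directory; mirrors the config above.
[program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    './mobilenet_dbb_inference',
    exe,
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams')

# Any variable in the program can be appended to the fetch list; the
# name below is model-specific and assumed here.
extra = next(v for v in program.list_vars() if v.name == 'conv2d_98.tmp_1')
fetch_targets.append(extra)

image = np.random.rand(1, 3, 224, 224).astype('float32')
logits, feature_map = exe.run(
    program, feed={feed_names[0]: image}, fetch_list=fetch_targets)
print(feature_map.shape)
```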
diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py
index 1917bd1462161e4725c83f38a4c4d70fbc5ded66..35ea705027027fb08156bb84e99c57d257596707 100644
--- a/paddleslim/nas/ofa/ofa.py
+++ b/paddleslim/nas/ofa/ofa.py
@@ -359,8 +359,8 @@ class OFA(OFABase):
             if isinstance(v, dict):
                 sample_cands[k] = self._sample_from_nestdict(
                     v, sample_type=sample_type, task=task, phase=phase)
-            elif isinstance(v, list) or isinstance(v, set) or isinstance(v,
-                                                                         tuple):
+            elif isinstance(v, list) or isinstance(v, set) or isinstance(
+                    v, tuple):
                 if sample_type == 'largest':
                     sample_cands[k] = v[-1]
                 elif sample_type == 'smallest':
@@ -413,8 +413,8 @@ class OFA(OFABase):
                     key = all_tokens.index(cand)
                     self.token_map[self.task][name][key] = cand
             else:
-                raise NotImplementedError("Task {} not in ofa layers".format(
-                    self.task))
+                raise NotImplementedError(
+                    "Task {} not in ofa layers".format(self.task))

         self.search_cands = []
         for layer, t_map in self.token_map[self.task].items():
@@ -610,8 +610,8 @@ class OFA(OFABase):
                         print(f"hit cpu in ofa-------------------------------")
                         place = paddle.CPUPlace()
                     else:
-                        place = paddle.framework.core.CUDAPlace(p.gpu_device_id(
-                        ))
+                        place = paddle.framework.core.CUDAPlace(
+                            p.gpu_device_id())
                     t_value.set(pruned_state_dict[name], place)

         if super_model_state_dict != None and len(super_model_state_dict) != 0:
@@ -739,10 +739,9 @@ class OFA(OFABase):
                 if key not in self._param2key.keys():
                     continue

-                ### if skip_layers and same ss both have same layer,
-                ### the layers related to this layer need to add to skip_layers
-                if self._skip_layers != None and self._param2key[
-                        key] in self._skip_layers:
+                ### if skip_layers and same ss both have same layer,
+                ### the layers related to this layer need to add to skip_layers
+                if self._skip_layers != None and self._param2key[key] in self._skip_layers:
                     self._skip_layers += [self._param2key[sk] for sk in ss]
                     per_ss = []
                     break
@@ -758,8 +757,8 @@ class OFA(OFABase):

         self._same_ss = tmp_same_ss

-        ### if fixed_by_input layer in a same ss,
-        ### layers in this same ss should all be fixed
+        ### if fixed_by_input layer in a same ss,
+        ### layers in this same ss should all be fixed
         tmp_fixed_by_input = []
         for ss in self._same_ss:
             for key in fixed_by_input:
@@ -781,7 +780,7 @@ class OFA(OFABase):
             set(output_conv + fixed_by_input + depthwise_conv))
         ### clear depthwise convs from search space because of its output channel cannot change
         ### clear output convs from search space because of model output shape cannot change
-        ### clear convs that operate with fixed input
+        ### clear convs that operate with fixed input
         for name, sublayer in model_to_traverse.named_sublayers():
             if isinstance(sublayer, BaseBlock):
                 for param in sublayer.parameters():
@@ -794,8 +793,8 @@ class OFA(OFABase):
         teacher_output = None
         if self._add_teacher:
             self._reset_hook_before_forward()
-            teacher_output = self.ofa_teacher_model.model.forward(*inputs,
-                                                                  **kwargs)
+            teacher_output = self.ofa_teacher_model.model.forward(
+                *inputs, **kwargs)

         # ============================================================
         # ==================== student process =====================
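The `_sample_from_nestdict` hunk above is a pure reformat; for readers unfamiliar with it, the method walks a nested candidate dictionary and picks one value per leaf. A self-contained sketch of that traversal with a hypothetical search space (illustration only, not PaddleSlim's API):

```python
import random

def sample_from_nestdict(cands, sample_type='random'):
    # Walk a nested dict of candidate lists/tuples/sets and pick one
    # value per leaf: the largest, the smallest, or a random choice.
    sampled = {}
    for k, v in cands.items():
        if isinstance(v, dict):
            sampled[k] = sample_from_nestdict(v, sample_type=sample_type)
        elif isinstance(v, (list, set, tuple)):
            v = sorted(v)
            if sample_type == 'largest':
                sampled[k] = v[-1]
            elif sample_type == 'smallest':
                sampled[k] = v[0]
            else:
                sampled[k] = random.choice(v)
    return sampled

# Hypothetical search space: expand ratios per block.
space = {'block_0': {'expand_ratio': [0.25, 0.5, 1.0]},
         'block_1': {'expand_ratio': [0.5, 1.0]}}
print(sample_from_nestdict(space, sample_type='largest'))
# {'block_0': {'expand_ratio': 1.0}, 'block_1': {'expand_ratio': 1.0}}
```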
diff --git a/paddleslim/quant/observers/__init__.py b/paddleslim/quant/observers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4cfb60a336dd9eacefd98a251f60347e0b366b8
--- /dev/null
+++ b/paddleslim/quant/observers/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .hist import HistObserver
+from .kl import KLObserver
+
+__all__ = ["HistObserver", "KLObserver"]
\ No newline at end of file
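The new package is meant to plug into Paddle's PTQ flow. A usage sketch combining both observers (it mirrors the test added at the end of this patch and assumes a PaddlePaddle version that provides `paddle.quantization.QuantConfig` and `PTQ`):

```python
import paddle
from paddle.quantization import PTQ, QuantConfig
from paddle.vision.models import resnet18

from paddleslim.quant.observers import HistObserver, KLObserver

# Observe activations with a percentile histogram and weights with KL.
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config(
    paddle.nn.Conv2D,
    activation=HistObserver(percent=0.999),
    weight=KLObserver(bins_count=1024))

model = resnet18()
model.eval()
ptq = PTQ(q_config)
quant_model = ptq.quantize(model, inplace=False)

# Calibrate with a few representative batches, then convert.
for _ in range(4):
    quant_model(paddle.rand([1, 3, 224, 224]))
infer_model = ptq.convert(quant_model, inplace=False)
```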
diff --git a/paddleslim/quant/observers/base_hist.py b/paddleslim/quant/observers/base_hist.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a4755071ba5c1f896bdca1e5554a5b9b0287ba9
--- /dev/null
+++ b/paddleslim/quant/observers/base_hist.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import Tuple
+import paddle
+import numpy as np
+
+from .uniform import UniformObserver
+
+
+class BaseHistObserver(UniformObserver):
+    """
+    It is a base class of histogram observers that defines functions for
+    collecting the values of multiple batches into a histogram.
+    Args:
+        quant_bits (int): The number of bits for quantization.
+        sign (bool): Whether the quantized integer includes a sign.
+        symmetric (bool): Whether it is symmetric quantization. In symmetric
+            quantization, the range of floating point values is relaxed to be
+            symmetric around zero and the zero-point is always 0.
+        bins_count(int): The number of equal-width bins.
+    """
+
+    def __init__(self, quant_bits=8, bins_count=2048, sign=True,
+                 symmetric=True):
+        super(BaseHistObserver, self).__init__(
+            quant_bits=quant_bits,
+            sign=sign,
+            symmetric=symmetric, )
+        self._bin_count = bins_count
+        self._upsample_bin_count = 64
+
+        self._hist_min = None
+        self._hist_max = None
+        self._hist = None
+
+    def _min_max(self, tensor):
+        """ Get the min and max value of a tensor.
+        """
+        return float(paddle.min(tensor).numpy()), float(
+            paddle.max(tensor).numpy())
+
+    def _init_hists(self, inputs):
+        """ Initialize the histogram instance based on a tensor.
+        """
+        _min, _max = self._min_max(inputs)
+        hist = None
+        if _max > _min:
+            hist, _ = np.histogram(
+                inputs.numpy(), range=(_min, _max), bins=self._bin_count)
+            hist = hist.astype(np.float32)
+        return hist
+
+    def forward(self, inputs):
+        self._scale = None
+        self._zero_point = None
+        self._min = None
+        self._max = None
+
+        if self._hist_min is None or self._hist_max is None:
+            self._hist_min, self._hist_max = self._min_max(inputs)
+            self._hist = self._init_hists(inputs)
+        else:
+            new_min, new_max, new_hist = self._update_min_max_and_hist(
+                inputs,
+                self._hist_min,
+                self._hist_max,
+                self._hist,
+                self._bin_count,
+                self._upsample_bin_count, )
+            self._hist_min, self._hist_max = new_min, new_max
+            self._hist = new_hist
+        return inputs
+
+    def _update_min_max_and_hist(self, tensor, origin_min, origin_max,
+                                 origin_hist, bins_count, upsample_bins_count):
+        """ Update the histogram and its range based on the values of the target tensor.
+        Args:
+            tensor: The tensor used to update the histogram.
+            origin_min(float): The minimum of the original histogram's range.
+            origin_max(float): The maximum of the original histogram's range.
+            origin_hist: The original histogram.
+            bins_count(int): The number of histogram bins.
+            upsample_bins_count(int): The number of upsampled bins used to extend the histogram.
+        """
+
+        _origin_min, _origin_max = origin_min, origin_max
+        _new_min, _new_max = self._min_max(tensor)
+
+        if (_new_max - _new_min) == 0.0:
+            return _origin_min, _origin_max, origin_hist
+        elif _origin_max - _origin_min == 0.0:
+            new_hist, _ = np.histogram(
+                tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
+            new_hist = new_hist.astype(np.float32)
+            return _new_min, _new_max, new_hist
+        elif _new_max <= _origin_max and _new_min >= _origin_min:
+            new_hist, _ = np.histogram(
+                tensor.numpy(),
+                range=(_origin_min, _origin_max),
+                bins=bins_count)
+            new_hist = new_hist.astype(np.float32)
+            new_hist += origin_hist
+            return _origin_min, _origin_max, new_hist
+        else:
+            _new_min = min(_new_min, _origin_min)
+            _new_max = max(_new_max, _origin_max)
+            _new_min, _new_max, downsample_bins_count, start_bin_idx = self._relax_min_max(
+                _new_min, _new_max, _origin_min, _origin_max, bins_count,
+                upsample_bins_count)
+
+            new_hist, _ = np.histogram(
+                tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
+
+            merged_histogram = self._merge_histograms(
+                new_hist, origin_hist, upsample_bins_count,
+                downsample_bins_count, start_bin_idx, bins_count)
+            return _new_min, _new_max, merged_histogram
+
+    def _merge_histograms(
+            self,
+            new_hist: np.ndarray,
+            origin_hist: np.ndarray,
+            upsample_bins_count: int,
+            downsample_bins_count: int,
+            start_bin_idx: int,
+            bins_count: int, ):
+        upsampled_histogram = np.repeat(origin_hist, upsample_bins_count)
+        expanded_hist = np.zeros(
+            (bins_count * downsample_bins_count), dtype=np.float32)
+        expanded_hist[start_bin_idx:bins_count * upsample_bins_count +
+                      start_bin_idx] = upsampled_histogram
+
+        cumsumed_hist = np.cumsum(
+            expanded_hist,
+            dtype=np.float64)[downsample_bins_count - 1::downsample_bins_count]
+        shift_cumsumed_hist = np.zeros((bins_count), dtype=np.float64)
+        shift_cumsumed_hist[1:] = cumsumed_hist[0:-1]
+        sampled_hist = (
+            cumsumed_hist - shift_cumsumed_hist) / upsample_bins_count
+        new_hist = new_hist.astype(np.float32)
+        new_hist += sampled_hist.astype(np.float32)
+        return new_hist
+
+    def _relax_min_max(self, new_min, new_max, origin_min, origin_max,
+                       bins_count,
+                       upsample_bins_count) -> Tuple[float, float, int, int]:
+        _bin_width = (origin_max - origin_min) / (
+            bins_count * upsample_bins_count)
+        downsample_bins_count = int(
+            np.ceil((new_max - new_min) / (bins_count * _bin_width)))
+        error = downsample_bins_count * bins_count * _bin_width - (
+            new_max - new_min)
+        new_max += error
+        start_bin_idx = round((origin_min - new_min) / _bin_width)
+        return new_min, new_max, downsample_bins_count, start_bin_idx
+
+    @abc.abstractmethod
+    def cal_min_max(self) -> Tuple[float, float]:
+        """ Calculate the minimum and maximum based on the histogram.
+        """
+        raise NotImplementedError("Please implement the abstract method.")
+
+    def cal_thresholds(self):
+        assert self._hist is not None
+        self._min, self._max = self.cal_min_max()
+        self._scale, self._zero_point = self.cal_scales_zero_points()
+
+    def min_value(self) -> float:
+        return self._min
+
+    def max_value(self) -> float:
+        return self._max
+
+    def bit_length(self):
+        return self._quant_bits
+
+    def quant_axis(self):
+        return -1
+
+    def scales(self):
+        if self._scale is None:
+            self.cal_thresholds()
+        return self._scale
+
+    def zero_points(self):
+        if self._zero_point is None:
+            self.cal_thresholds()
+        return self._zero_point
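To make the merge step in `_merge_histograms` concrete: the old histogram is upsampled over the widened range and re-binned so it can be added to the new one, preserving total mass. A small numpy sketch with made-up numbers, equivalent to the cumulative-sum formulation above:

```python
import numpy as np

# Old histogram: 4 bins over [0, 4). New range widens to [0, 8).
bins_count = 4
upsample = 8                      # upsample factor (64 in the observer)
old_hist = np.array([1., 2., 3., 2.])
old_min, old_max = 0.0, 4.0
new_min, new_max = 0.0, 8.0

# Each old bin is split into `upsample` sub-bins of equal mass...
bin_width = (old_max - old_min) / (bins_count * upsample)
downsample = int(np.ceil((new_max - new_min) / (bins_count * bin_width)))
start = round((old_min - new_min) / bin_width)

up = np.repeat(old_hist, upsample)
expanded = np.zeros(bins_count * downsample)
expanded[start:start + bins_count * upsample] = up

# ...then groups of `downsample` sub-bins are summed back into
# `bins_count` bins over the new range, preserving total mass.
merged = expanded.reshape(bins_count, downsample).sum(axis=1) / upsample
print(merged, merged.sum(), old_hist.sum())  # [3. 5. 0. 0.], 8.0 == 8.0
```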
""" + raise NotImplementedError("Please implement the abstract method.") + + def cal_thresholds(self): + assert self._hist is not None + self._min, self._max = self.cal_min_max() + self._scale, self._zero_point = self.cal_scales_zero_points() + + def min_value(self) -> float: + return self._min + + def max_value(self) -> float: + return self._max + + def bit_length(self): + return self._quant_bits + + def quant_axis(self): + return -1 + + def scales(self): + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/paddleslim/quant/observers/hist.py b/paddleslim/quant/observers/hist.py new file mode 100644 index 0000000000000000000000000000000000000000..a53d6c60ee5d0c2fcdb2c0094be5853bbb3ba46d --- /dev/null +++ b/paddleslim/quant/observers/hist.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np + +from .base_hist import BaseHistObserver +from paddle.quantization.factory import ObserverFactory + + +class HistObserver(ObserverFactory): + r""" + It collects tensor values into a histogram. And calculate quantization parameters + based on a percent ratio. + + Args: + quant_bits (int): The number of bits for quantization. + bins_count(int): The number of equal-width bins. + percent(float): The percentage of bins that are retained when clipping the outliers. + sign (bool): Whether the quantized integer includes a sign. + symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric. + In symmetric quantization, the range of floating point values is relaxed to be symmetric + around zero and the zero-point is always 0. + + + Examples: + .. code-block:: python + + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import HistObserver + quanter = HistObserver() + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, + quant_bits=8, + bins_count=2048, + percent=0.999, + sign=True, + symmetric=True): + super(HistObserver, self).__init__( + quant_bits=quant_bits, + bins_count=bins_count, + percent=percent, + sign=sign, + symmetric=symmetric) + + def _get_class(self): + return PercentHistObserverLayer + + +class PercentHistObserverLayer(BaseHistObserver): + r""" + It collects tensor values into a histogram. And calculate quantization parameters + based on a percent ratio. 
+ """ + + def __init__(self, + layer, + quant_bits=8, + bins_count=2048, + percent=0.999, + sign=True, + symmetric=True): + super(PercentHistObserverLayer, self).__init__( + quant_bits=quant_bits, + bins_count=bins_count, + sign=sign, + symmetric=symmetric) + + self._percent = percent + + def _cal_min_max_by_percent(self): + hist = self._hist / np.sum(self._hist, dtype=np.float64) + cumsumed_hist = np.cumsum(hist) + max_idx = np.argwhere(cumsumed_hist >= self._percent)[0] + min_idx = np.argwhere(cumsumed_hist >= (1 - self._percent))[0] + bin_width = (self._hist_max - self._hist_min) / hist.shape[0] + _max = self._hist_min + float((max_idx - 0.5) * bin_width) + _min = self._hist_min + float((min_idx - 0.5) * bin_width) + return _min, _max + + def cal_min_max(self): + return self._cal_min_max_by_percent() diff --git a/paddleslim/quant/observers/kl.py b/paddleslim/quant/observers/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..b9653ff21916a7ad6929d616127675d365847ab1 --- /dev/null +++ b/paddleslim/quant/observers/kl.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math + +from .base_hist import BaseHistObserver +from paddle.quantization.factory import ObserverFactory + + +class KLObserver(ObserverFactory): + r""" + Calculate quantization parameters that minimize the Kullback–Leibler divergence + between the distribution of floating values and the distribution of quantized + floating values. + + Args: + quant_bits (int): The number of bits for quantization. + bins_count(int): The number of equal-width bins. + + Examples: + .. code-block:: python + + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import KLObserver + quanter = KLObserver() + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8, bins_count=2048): + super(KLObserver, self).__init__( + quant_bits=quant_bits, bins_count=bins_count) + + def _get_class(self): + return KLObserverLayer + + +class KLObserverLayer(BaseHistObserver): + """ + Per-tensor KL observer. + """ + + def __init__(self, layer, quant_bits=8, bins_count=2048): + super(KLObserverLayer, self).__init__( + quant_bits=quant_bits, + bins_count=bins_count, + sign=True, + symmetric=True) + + def _search_min_max_by_kl(self): + bin_width = (self._hist_max - self._hist_min) / self._bin_count + _max = cal_kl_threshold(self._hist, bin_width, self.bit_length()) + return 0., _max + + def cal_min_max(self): + return self._search_min_max_by_kl() + + +def expand_quantized_bins(quantized_bins, reference_bins): + ''' + Expand hist bins. 
diff --git a/paddleslim/quant/observers/uniform.py b/paddleslim/quant/observers/uniform.py
new file mode 100644
index 0000000000000000000000000000000000000000..216418a91c1d5bb053966e509da456cb5f17b991
--- /dev/null
+++ b/paddleslim/quant/observers/uniform.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import Tuple
+import numpy as np
+from paddle.quantization.base_observer import BaseObserver
+
+
+class UniformObserver(BaseObserver):
+    """ This is the base class for a uniform quantization observer, which provides
+    common functions for calculating the scale and zero-point used in uniform quantization.
+    Uniform quantization maps floating point values to integers, where the scale determines
+    the step size of the quantizer and the floating point zero is mapped to the zero-point,
+    an integer value ensuring that zero is quantized without error.
+
+    Args:
+        quant_bits (int): The number of bits for quantization.
+        sign (bool): Whether the quantized integer includes a sign.
+        symmetric (bool): Whether it is symmetric quantization. In symmetric
+            quantization, the range of floating point values is relaxed to be
+            symmetric around zero and the zero-point is always 0.
+    """
+
+    def __init__(
+            self,
+            quant_bits=8,
+            sign=True,
+            symmetric=True, ):
+        super(UniformObserver, self).__init__()
+        self._quant_bits = quant_bits
+        self._sign = sign
+        self._symmetric = symmetric
+
+        self._min = None
+        self._max = None
+        self._qmin = None
+        self._qmax = None
+
+        self._scale = None
+        self._zero_point = None
+
+    @property
+    def qmin_qmax(self):
+        """ Calculate the range of the quantized integer based on the specified
+        quant_bits, sign, and symmetric properties."""
+        if self._sign:
+            self._qmin = -2**(self.bit_length() - 1)
+            self._qmax = 2**(self.bit_length() - 1) - 1
+        else:
+            self._qmin = 0
+            self._qmax = 2**self.bit_length()
+        return self._qmin, self._qmax
+
+    @abc.abstractmethod
+    def min_value(self) -> float:
+        """ The minimum value of floating-point numbers."""
+        raise NotImplementedError(
+            "Please implement the abstract method to get the minimum value of floating-point numbers."
+        )
+
+    @abc.abstractmethod
+    def max_value(self) -> float:
+        """ The maximum value of floating-point numbers."""
+        raise NotImplementedError(
+            "Please implement the abstract method to get the maximum value of floating-point numbers."
+        )
+
+    def cal_scales_zero_points(self) -> Tuple[float, float]:
+        """ Calculate the scales and zero points based on the min_value and max_value.
+        """
+        assert self.min_value() is not None and self.max_value() is not None
+        _qmin, _qmax = self.qmin_qmax
+        # For one-sided distributions, the range (_min, _max) is relaxed to include zero.
+        # It is important to ensure that common operations like zero padding do not cause quantization errors.
+        _min = min(self.min_value(), 0.)
+        _max = max(self.max_value(), 0.)
+
+        if self._symmetric:
+            self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
+            if self._sign:
+                self._zero_point = 0
+            else:
+                self._zero_point = (_qmax + _qmin) / 2
+        else:
+            self._scale = (_max - _min) / float(_qmax - _qmin)
+            self._zero_point = _qmin - round(_min / self._scale)
+            self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
+        return self._scale, self._zero_point
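A worked example of the arithmetic in `cal_scales_zero_points`, for an observed range of [-0.3, 1.0] with 8-bit signed integers (illustrative numbers only):

```python
# Observed float range after relaxing to include zero.
_min, _max = -0.3, 1.0
qmin, qmax = -128, 127

# Symmetric: the range is widened to [-1.0, 1.0] and the zero-point is 0.
scale_sym = max(-_min, _max) / ((qmax - qmin) / 2)  # 1.0 / 127.5 ~= 0.00784
zero_point_sym = 0

# Asymmetric: the full [-0.3, 1.0] span is used, so the step is finer, and
# the zero-point shifts so that float 0.0 maps to an exact integer.
scale_asym = (_max - _min) / float(qmax - qmin)    # 1.3 / 255 ~= 0.0051
zero_point_asym = qmin - round(_min / scale_asym)  # -128 + 59 = -69

print(scale_sym, scale_asym, zero_point_asym)
```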
diff --git a/tests/quantization/test_observers.py b/tests/quantization/test_observers.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e1139ffa17ffd638e3c2e9fe50fed54418b623
--- /dev/null
+++ b/tests/quantization/test_observers.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append("../../")
+import os
+import unittest
+import paddle
+import tempfile
+from paddle.vision.models import resnet18
+from paddle.quantization import QuantConfig
+from paddle.quantization import PTQ
+
+from paddleslim.quant.observers import HistObserver, KLObserver
+from paddleslim.quant.observers.hist import PercentHistObserverLayer
+from paddleslim.quant.observers.kl import KLObserverLayer
+from paddle.nn.quant.format import LinearDequanter, LinearQuanter
+
+
+class TestPTQWithHistObserver(unittest.TestCase):
+    def __init__(self, observer, observer_type, *args, **kvargs):
+        super(TestPTQWithHistObserver, self).__init__(*args, **kvargs)
+        self.observer = observer
+        self.observer_type = observer_type
+
+    def setUp(self):
+        paddle.set_device("cpu")
+        self.init_case()
+        self.dummy_input = paddle.rand([1, 3, 224, 224])
+        self.temp_dir = tempfile.TemporaryDirectory(dir="./")
+        self.path = os.path.join(self.temp_dir.name, 'qat')
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def runTest(self):
+        self.test_quantize()
+        self.test_convert()
+
+    def init_case(self):
+        # observer = HistObserver()
+        # self.observer_type = PercentHistObserverLayer
+        self.q_config = QuantConfig(activation=None, weight=None)
+        self.q_config.add_type_config(
+            paddle.nn.Conv2D, activation=self.observer, weight=self.observer)
+
+    def _count_layers(self, model, layer_type):
+        count = 0
+        for _layer in model.sublayers(True):
+            if isinstance(_layer, layer_type):
+                count += 1
+        return count
+
+    def test_quantize(self):
+        model = resnet18()
+        conv_count = self._count_layers(model, paddle.nn.Conv2D)
+        ptq = PTQ(self.q_config)
+        model.eval()
+        quant_model = ptq.quantize(model, inplace=False)
+        zero_input = paddle.zeros_like(self.dummy_input)
+        out = quant_model(zero_input)
+        out = quant_model(self.dummy_input)
+        out = quant_model(zero_input)
+        out = quant_model(self.dummy_input + 1.)
+        quantizer_cnt = self._count_layers(quant_model, self.observer_type)
+        self.assertEqual(quantizer_cnt, 2 * conv_count)
+
+    def test_convert(self):
+        model = resnet18()
+        conv_count = self._count_layers(model, paddle.nn.Conv2D)
+        ptq = PTQ(self.q_config)
+        model.eval()
+        quant_model = ptq.quantize(model, inplace=False)
+        out = quant_model(self.dummy_input)
+        converted_model = ptq.convert(quant_model, inplace=False)
+
+        # check count of LinearQuanter and LinearDequanter in dygraph
+        quantizer_count_in_dygraph = self._count_layers(converted_model,
+                                                        LinearQuanter)
+        dequantizer_count_in_dygraph = self._count_layers(
+            converted_model, LinearDequanter)
+        self.assertEqual(quantizer_count_in_dygraph, conv_count)
+        self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)
+
+
+observer_suite = unittest.TestSuite()
+observer_suite.addTest(
+    TestPTQWithHistObserver(
+        observer=HistObserver(sign=True, symmetric=True),
+        observer_type=PercentHistObserverLayer))
+observer_suite.addTest(
+    TestPTQWithHistObserver(
+        observer=HistObserver(sign=False, symmetric=True),
+        observer_type=PercentHistObserverLayer))
+observer_suite.addTest(
+    TestPTQWithHistObserver(
+        observer=HistObserver(sign=True, symmetric=False),
+        observer_type=PercentHistObserverLayer))
+observer_suite.addTest(
+    TestPTQWithHistObserver(
+        observer=HistObserver(sign=False, symmetric=False),
+        observer_type=PercentHistObserverLayer))
+observer_suite.addTest(
+    TestPTQWithHistObserver(
+        observer=KLObserver(bins_count=256), observer_type=KLObserverLayer))
+
+if __name__ == '__main__':
+    runner = unittest.TextTestRunner(verbosity=2)
+    runner.run(observer_suite)