From 2b6fc108ce6c67849179d8e5853b3131c9611e77 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Tue, 22 Jun 2021 10:22:45 +0800
Subject: [PATCH] Dygraph post training quantization (#33445)

* dygraph post training quantization
* refine the ptq config
* refine ptq quantizer
---
 .../slim/quantization/imperative/__init__.py  |  16 +
 .../slim/quantization/imperative/ptq.py       | 112 +++++++
 .../quantization/imperative/ptq_config.py     |  44 +++
 .../slim/quantization/imperative/ptq_hooks.py |  28 ++
 .../quantization/imperative/ptq_quantizer.py  | 261 ++++++++++++++++
 .../quantization/imperative/ptq_registry.py   |  86 ++++++
 .../slim/quantization/imperative/utils.py     | 119 +++++++-
 .../fluid/contrib/slim/tests/CMakeLists.txt   |   2 +
 .../contrib/slim/tests/test_imperative_ptq.py | 288 ++++++++++++++++++
 9 files changed, 952 insertions(+), 4 deletions(-)
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py
 create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py

diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
index 7ea62b5f324..77872e88a07 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
@@ -20,6 +20,22 @@ from .quant_nn import *
 from . import qat
 from .qat import *
 
+from . import ptq
+from .ptq import *
+
+from . import ptq_config
+from .ptq_config import *
+
+from . import ptq_quantizer
+from .ptq_quantizer import *
+
+from . import ptq_registry
+from .ptq_registry import *
+
 __all__ = []
 __all__ += quant_nn.__all__
 __all__ += qat.__all__
+__all__ += ptq.__all__
+__all__ += ptq_config.__all__
+__all__ += ptq_quantizer.__all__
+__all__ += ptq_registry.__all__

diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
new file mode 100644
index 00000000000..a275ca6f3cd
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import copy
+import numpy as np
+
+import paddle
+from paddle.fluid.log_helper import get_logger
+
+from . import utils
+from . import ptq_hooks
+from . 
import ptq_config +from .ptq_registry import PTQRegistry + +__all__ = ['ImperativePTQ'] + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class ImperativePTQ(object): + """ + Applying static post_training quantization to the dgraph model. + """ + + def __init__(self, quant_config=ptq_config.default_ptq_config): + """ + Constructor. + Args: + algo(str): The algorithm in post_training quantizaion to be used. + activation_bits(int): quantization bit number for activations. + weight_bits(int): quantization bit number for weights. + """ + super(ImperativePTQ, self).__init__() + + assert isinstance(quant_config, ptq_config.PTQConfig) + + self._quant_config = quant_config + + def quantize(self, model, inplace=False): + """ + Add hook to the leaf layer to calculate the threshold of inputs and outputs. + + Args: + model(paddle.nn.Layer): The model to be quantized. + Returns: + None + """ + assert isinstance(model, paddle.nn.Layer), \ + "The model must be the instance of paddle.nn.Layer." + + if not inplace: + model = copy.deepcopy(model) + + for name, layer in model.named_sublayers(): + if PTQRegistry.is_supported_layer(layer) \ + and utils.is_leaf_layer(layer): + quant_config = copy.deepcopy(self._quant_config) + layer._quant_config = quant_config + + hook = ptq_hooks.quant_forward_post_hook + hook_handle = layer.register_forward_post_hook(hook) + quant_config.hook_handle = hook_handle + layer._forward_post_hooks.move_to_end( + hook_handle._hook_id, last=False) + + return model + + def convert(self, model): + """ + Process the scales and remove the hooks. + + Args: + model(paddle.nn.Layer): The model to be quantized. + Returns: + None + """ + assert isinstance(model, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + for name, sub_layer in model.named_sublayers(): + if PTQRegistry.is_supported_layer(sub_layer) \ + and utils.is_leaf_layer(sub_layer): + + assert hasattr(sub_layer, "_quant_config") + quant_config = sub_layer._quant_config + quant_config.hook_handle.remove() + + quant_config.in_act_quantizer.cal_thresholds() + quant_config.out_act_quantizer.cal_thresholds() + + # get weight thresholds + if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)): + weights = (sub_layer.weight, ) + quant_config.wt_quantizer.sample_data(sub_layer, weights) + + # TODO (jc): + # save input activation threshold and quant bits + + return model diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py new file mode 100644 index 00000000000..3b741cc4644 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import six +import abc +import copy + +import paddle + +from .ptq_quantizer import * + +__all__ = ['PTQConfig', 'default_ptq_config'] + + +class PTQConfig(object): + """ + The PTQ config shows how to quantize the inputs and outputs. + """ + + def __init__(self, activation_quantizer, weight_quantizer): + super(PTQConfig, self).__init__() + + assert isinstance(activation_quantizer, BaseQuantizer) + assert isinstance(weight_quantizer, BaseQuantizer) + + self.in_act_quantizer = copy.deepcopy(activation_quantizer) + self.out_act_quantizer = copy.deepcopy(activation_quantizer) + self.wt_quantizer = copy.deepcopy(weight_quantizer) + + self.hook_handle = None + + +default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py new file mode 100644 index 00000000000..82a277ad28e --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import math +import numpy as np +from . import ptq_config + + +def quant_forward_post_hook(layer, inputs, outputs): + """ + The forward_post_hook for PTQ. + """ + assert hasattr(layer, '_quant_config'), \ + "The layer should have _quant_config attr" + layer._quant_config.in_act_quantizer.sample_data(layer, inputs) + layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, )) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py new file mode 100644 index 00000000000..362cc0e0e4a --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py @@ -0,0 +1,261 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import abc +import copy +import math +import numpy as np + +import paddle + +from . import utils + +__all__ = [ + 'BaseQuantizer', + 'AbsmaxQuantizer', + 'PerChannelAbsmaxQuantizer', + 'KLQuantizer', + 'HistQuantizer', +] + + +def abs_max_value(tensor): + return float(paddle.max(paddle.abs(tensor)).numpy()) + + +def merge_max_value(old, new): + """ + Merge the max element one by one in two lists. 
+ """ + assert isinstance(old, list) and isinstance(new, list) + if old != []: + assert len(old) == len(new) + for i in range(len(old)): + assert type(old[i]) == type(new[i]) + if isinstance(old[i], list): + new[i] = merge_max_value(old[i], new[i]) + else: + new[i] = old[i] if new[i] < old[i] else new[i] + return new + + +def combine_abs_max_and_hist(tensor, origin_max, origin_hist, bins, + upsample_bins): + """ + """ + + new_max = abs_max_value(tensor) + + if new_max == 0.0: + return origin_max, origin_hist + elif origin_max == 0.0: + new_hist, _ = np.histogram( + paddle.abs(tensor).numpy(), range=(0, new_max), bins=bins) + new_hist = new_hist.astype(np.float32) + return new_max, new_hist + elif new_max <= origin_max: + new_hist, _ = np.histogram( + paddle.abs(tensor).numpy(), range=(0, origin_max), bins=bins) + new_hist = new_hist.astype(np.float32) + new_hist += origin_hist + return origin_max, new_hist + else: + # bin_width = origin_max / (bins * upsample_bins) + # = new_max / (bins * downsample_bins) + bin_width = origin_max / (bins * upsample_bins) + downsampe_bins = int(math.ceil(new_max / (bins * bin_width))) + new_max = bins * bin_width * downsampe_bins + + upsampled_hist = np.repeat(origin_hist, upsample_bins) + expanded_hist = np.zeros((bins * downsampe_bins), dtype=np.float32) + expanded_hist[0:bins * upsample_bins] = upsampled_hist + cumsumed_hist = np.cumsum( + expanded_hist, dtype=np.float64)[downsampe_bins - 1::downsampe_bins] + shift_cumsumed_hist = np.zeros((bins), dtype=np.float64) + shift_cumsumed_hist[1:] = cumsumed_hist[0:-1] + sampled_hist = (cumsumed_hist - shift_cumsumed_hist) / upsample_bins + sampled_hist = sampled_hist.astype(np.float32) + + new_hist, _ = np.histogram( + paddle.abs(tensor).numpy(), range=(0, new_max), bins=bins) + new_hist = new_hist.astype(np.float32) + new_hist += sampled_hist + + return new_max, new_hist + + +@six.add_metaclass(abc.ABCMeta) +class BaseQuantizer(object): + """ + Base quantizer for activation and weight. + """ + + def __init__(self, quant_bits=8): + super(BaseQuantizer, self).__init__() + assert isinstance(quant_bits, int) + assert quant_bits > 0 and quant_bits <= 16 + + self.quant_bits = quant_bits + + self.thresholds = [] + + @abc.abstractmethod + def sample_data(self, layer, tensors): + pass + + @abc.abstractmethod + def cal_thresholds(self): + pass + + +class AbsmaxQuantizer(BaseQuantizer): + """ + Per-tensor abs max quantizer. + """ + + def __init__(self, quant_bits=8): + super(AbsmaxQuantizer, self).__init__(quant_bits) + + def sample_data(self, layer, tensors): + assert isinstance(tensors, tuple) + + abs_max_vals = [abs_max_value(t) for t in tensors] + self.thresholds = merge_max_value(self.thresholds, abs_max_vals) + + def cal_thresholds(self): + pass + + +class PerChannelAbsmaxQuantizer(BaseQuantizer): + """ + Per channel abs max quantizer. 
+ """ + + def __init__(self, quant_bits=8): + super(PerChannelAbsmaxQuantizer, self).__init__(quant_bits) + + def sample_data(self, layer, tensors): + assert isinstance(layer, paddle.nn.Layer) + assert isinstance(tensors, tuple) + + abs_max_vals_list = [] + for idx, tensor in enumerate(tensors): + if isinstance(layer, tuple(utils.spec_channel_axis_layers)): + abs_max_vals = [ + abs_max_value(tensor[:, i]) for i in range(tensor.shape[1]) + ] + abs_max_vals_list.append(abs_max_vals) + else: + abs_max_vals = [ + abs_max_value(tensor[i]) for i in range(tensor.shape[0]) + ] + abs_max_vals_list.append(abs_max_vals) + + self.thresholds = merge_max_value(self.thresholds, abs_max_vals_list) + + def cal_thresholds(self): + pass + + +@six.add_metaclass(abc.ABCMeta) +class BaseHistQuantizer(BaseQuantizer): + """ + """ + + def __init__(self, quant_bits=8, bins=1024, upsample_bins=64): + super(BaseHistQuantizer, self).__init__(quant_bits) + self.bins = bins + self.upsample_bins = upsample_bins + + self.abs_max_vals = [] + self.hists = [] + + def sample_data(self, layer, tensors): + assert isinstance(tensors, tuple) + + if self.abs_max_vals == []: + abs_max_vals = [abs_max_value(t) for t in tensors] + self.abs_max_vals = abs_max_vals + + for idx, tensor in enumerate(tensors): + if abs_max_vals[idx] == 0.0: + self.hists.append(None) + else: + hist, _ = np.histogram( + paddle.abs(tensor).numpy(), + range=(0., abs_max_vals[idx]), + bins=self.bins) + hist = hist.astype(np.float32) + self.hists.append(hist) + else: + assert len(self.abs_max_vals) == len(tensors) + assert len(self.hists) == len(tensors) + + for idx, tensor in enumerate(tensors): + new_abs_max, new_hist = combine_abs_max_and_hist( + tensor, self.abs_max_vals[idx], self.hists[idx], self.bins, + self.upsample_bins) + self.abs_max_vals[idx] = new_abs_max + self.hists[idx] = new_hist + + @abc.abstractmethod + def cal_thresholds(self): + pass + + +class HistQuantizer(BaseHistQuantizer): + """ + """ + + def __init__(self, + quant_bits=8, + bins=1024, + upsample_bins=64, + hist_percent=0.99999): + super(HistQuantizer, self).__init__(quant_bits, bins, upsample_bins) + self.hist_percent = hist_percent + + def cal_thresholds(self): + def _helper(abs_max, hist, percent): + assert hist.ndim == 1 and percent < 1.0 + hist = hist / np.sum(hist, dtype=np.float64) + cumsumed_hist = np.cumsum(hist) + index = np.argwhere(cumsumed_hist >= percent)[0] + return float((index - 0.5) * (abs_max / hist.shape[0])) + + for idx in range(len(self.hists)): + if self.hists[idx] is None: + self.thresholds.append(self.abs_max_vals[idx]) + else: + threshold = _helper(self.abs_max_vals[idx], self.hists[idx], + self.hist_percent) + self.thresholds.append(threshold) + + +class KLQuantizer(BaseHistQuantizer): + """ + """ + + def __init__(self, quant_bits=8, bins=1024, upsample_bins=64): + super(KLQuantizer, self).__init__(quant_bits, bins, upsample_bins) + + def cal_thresholds(self): + for idx in range(len(self.hists)): + if self.hists[idx] is None: + self.thresholds.append(self.abs_max_vals[idx]) + else: + threshold = utils.cal_kl_scaling_factor( + self.hists[idx], self.abs_max_vals[idx], self.quant_bits) + self.thresholds.append(threshold) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py new file mode 100644 index 00000000000..973d66303ec --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +__all__ = ['PTQRegistry'] + + +class LayerInfo(object): + """ + Store the argnames of the inputs and outputs. + """ + + def __init__(self, layer, input_names, weight_names, output_names): + super(LayerInfo, self).__init__() + self.layer = layer + self.input_names = input_names + self.weight_names = weight_names + self.output_names = output_names + + +PTQ_LAYERS_INFO = [ + LayerInfo(paddle.nn.Conv2D, ['Input'], ['Filter'], ['Output']), + LayerInfo(paddle.nn.Linear, ['X'], ['Y'], ['Out']), + LayerInfo(paddle.nn.BatchNorm2D, ['X'], [], ['Y']), + LayerInfo(paddle.nn.AdaptiveMaxPool2D, ['X'], [], ['Out']), + LayerInfo(paddle.nn.AdaptiveAvgPool2D, ['X'], [], ['Out']), + LayerInfo(paddle.nn.AvgPool2D, ['X'], [], ['Out']), + LayerInfo(paddle.nn.MaxPool2D, ['X'], [], ['Out']), + LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']), + LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']), + LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']), + LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']), + LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']), + LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']), + LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']), +] + + +class PTQRegistry(object): + """ + Register the supported layers for PTQ and provide layers info. + """ + supported_layers_map = {} + is_inited = False + + def __init__(self): + super(PTQRegistry, self).__init__() + + @classmethod + def _init(cls): + if not cls.is_inited: + for layer_info in PTQ_LAYERS_INFO: + cls.supported_layers_map[layer_info.layer] = layer_info + cls.is_inited = True + + @classmethod + def is_supported_layer(cls, layer): + """ + Analyze whether the layer supports quantization. + """ + cls._init() + return layer in cls.supported_layers_map or \ + isinstance(layer, tuple(cls.supported_layers_map.keys())) + + def layer_info(cls, layer): + """ + Get the infomation for the supported layer. + """ + assert cls.is_supported_layer( + layer), "The input layer is not supported." + + for layer_key, layer_info in cls.supported_layers_map.items(): + if layer == layer_key or isinstance(layer, layer_key): + return layer_info diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 94639b9cc68..98eefc73608 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle -from paddle.fluid import dygraph +import math import numpy as np + +import paddle + from . 
import quant_nn layer_name_map = { @@ -60,6 +62,9 @@ fake_quant_leaf_layers = [ fake_quant_wrap_layers = [quant_nn.QuantizedConv2D, quant_nn.QuantizedLinear] +# The weight format of these layers is Cin * Cout * H * W +spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose] + weight_op_types = [ "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose", "depthwise_conv2d_transpose" @@ -109,7 +114,7 @@ def find_parent_layer_and_sub_name(model, name): For example, if name is 'block_1/convbn_1/conv_1', the parent layer is 'block_1/convbn_1' and the sub_name is `conv_1`. """ - assert isinstance(model, dygraph.Layer), \ + assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." assert len(name) > 0, "The input (name) should not be empty." @@ -131,5 +136,111 @@ def is_leaf_layer(layer): """ Whether the layer is leaf layer. """ - return isinstance(layer, dygraph.Layer) \ + return isinstance(layer, paddle.nn.Layer) \ and len(layer.sublayers()) == 0 + + +def expand_quantized_bins(quantized_bins, reference_bins): + """ + """ + expanded_quantized_bins = [0] * len(reference_bins) + num_merged_bins = int(len(reference_bins) / len(quantized_bins)) + j_start = 0 + j_end = num_merged_bins + for idx in range(len(quantized_bins)): + zero_count = reference_bins[j_start:j_end].count(0) + num_merged_bins = j_end - j_start + if zero_count == num_merged_bins: + avg_bin_ele = 0 + else: + avg_bin_ele = quantized_bins[idx] / ( + num_merged_bins - zero_count + 0.0) + for idx1 in range(j_start, j_end): + expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else + avg_bin_ele) + j_start += num_merged_bins + j_end += num_merged_bins + if (idx + 1) == len(quantized_bins) - 1: + j_end = len(reference_bins) + return expanded_quantized_bins + + +def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum): + ''' + Calculate the entropy. + ''' + assert len(reference_distr_P) == len(candidate_distr_Q) + tmp_sum1 = 0 + tmp_sum2 = 0 + for idx in range(len(reference_distr_P)): + p_idx = reference_distr_P[idx] + q_idx = candidate_distr_Q[idx] + if p_idx == 0: + tmp_sum1 += 0 + tmp_sum2 += 0 + else: + if q_idx == 0: + _logger.error("Fatal error!, idx = " + str(idx) + + " qindex = 0! p_idx = " + str(p_idx)) + tmp_sum1 += p_idx * (math.log(Q_sum * p_idx)) + tmp_sum2 += p_idx * (math.log(P_sum * q_idx)) + return (tmp_sum1 - tmp_sum2) / P_sum + + +def cal_kl_scaling_factor(hist, abs_max, bits): + ''' + Using the KL-divergenc method to get the more precise scaling factor. 
+ ''' + assert hist.ndim == 1 + hist_bins = hist.shape[0] + starting_iter = int((hist_bins - 1) * 0.5) + bin_width = abs_max / hist_bins + quant_range = 2**(bits - 1) - 1 + + P_sum = np.sum(np.array(hist).ravel()) + min_kl_divergence = 0 + min_kl_index = 0 + kl_inited = False + + for i in range(starting_iter, hist_bins): + reference_distr_P = hist[0:i].tolist() + outliers_count = sum(hist[i:]) + if reference_distr_P[i - 1] == 0: + continue + reference_distr_P[i - 1] += outliers_count + reference_distr_bins = reference_distr_P[:] + candidate_distr_Q = hist[0:i].tolist() + num_merged_bins = int(i / quant_range) + candidate_distr_Q_quantized = [0] * quant_range + j_start = 0 + j_end = num_merged_bins + for idx in range(quant_range): + candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[j_start: + j_end]) + j_start += num_merged_bins + j_end += num_merged_bins + if (idx + 1) == quant_range - 1: + j_end = i + candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized, + reference_distr_bins) + Q_sum = sum(candidate_distr_Q) + kl_divergence = safe_entropy(reference_distr_P, P_sum, + candidate_distr_Q, Q_sum) + if not kl_inited: + min_kl_divergence = kl_divergence + min_kl_index = i + kl_inited = True + elif kl_divergence < min_kl_divergence: + min_kl_divergence = kl_divergence + min_kl_index = i + else: + pass + if min_kl_index == 0: + while starting_iter > 0: + if hist[starting_iter] == 0: + starting_iter -= 1 + continue + else: + break + min_kl_index = starting_iter + return (min_kl_index + 0.5) * bin_width diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 5a4f7c0a1fd..febed599783 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -125,6 +125,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) + list(REMOVE_ITEM TEST_OPS test_imperative_ptq) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) @@ -300,6 +301,7 @@ if(NOT WIN32) set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120) + set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py new file mode 100644 index 00000000000..30ba53e2fcf --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -0,0 +1,288 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. 
+# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function + +import os +import numpy as np +import random +import shutil +import time +import unittest +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid.contrib.slim.quantization import * +from paddle.fluid.log_helper import get_logger +from paddle.dataset.common import download + +from imperative_test_utils import fix_model_dict, ImperativeLenet + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class TestImperativePTQ(unittest.TestCase): + """ + """ + + @classmethod + def setUpClass(cls): + timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + cls.root_path = os.path.join(os.getcwd(), "imperative_ptq_" + timestamp) + cls.save_path = os.path.join(cls.root_path, "model") + + cls.download_path = 'dygraph_int8/download' + cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + + cls.download_path) + + cls.lenet_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz" + cls.lenet_md5 = "953b802fb73b52fae42896e3c24f0afb" + + seed = 1 + np.random.seed(seed) + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed + + @classmethod + def tearDownClass(cls): + try: + shutil.rmtree(cls.root_path) + except Exception as e: + print("Failed to delete {} due to {}".format(cls.root_path, str(e))) + + def cache_unzipping(self, target_folder, zip_path): + if not os.path.exists(target_folder): + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, + zip_path) + os.system(cmd) + + def download_model(self, data_url, data_md5, folder_name): + download(data_url, self.download_path, data_md5) + file_name = data_url.split('/')[-1] + zip_path = os.path.join(self.cache_folder, file_name) + print('Data is downloaded at {0}'.format(zip_path)) + + data_cache_folder = os.path.join(self.cache_folder, folder_name) + self.cache_unzipping(data_cache_folder, zip_path) + return data_cache_folder + + def set_vars(self): + self.ptq = ImperativePTQ(default_ptq_config) + + self.batch_num = 10 + self.batch_size = 10 + self.eval_acc_top1 = 0.99 + + self.gt_thresholds = { + 'conv2d_0': [[1.0], [0.37673383951187134], [0.10933732241392136]], + 'batch_norm2d_0': [[0.37673383951187134], [0.44249194860458374]], + 're_lu_0': [[0.44249194860458374], [0.25804123282432556]], + 'max_pool2d_0': [[0.25804123282432556], [0.25804123282432556]], + 'linear_0': + [[1.7058950662612915], [14.405526161193848], [0.4373355209827423]], + 'add_0': [[1.7058950662612915, 0.0], [1.7058950662612915]], + } + + def model_train(self, model, train_reader, max_step=-1): + model.train() + adam = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + + out = model(img) + acc = fluid.layers.accuracy(out, label) + 
loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + + adam.minimize(avg_loss) + model.clear_gradients() + + if batch_id % 100 == 0: + _logger.info("Train | step {}: loss = {:}, acc= {:}".format( + batch_id, avg_loss.numpy(), acc.numpy())) + + if max_step > 0 and batch_id > max_step: # For shortening CI time + break + + def model_test(self, model, batch_num=-1, batch_size=8): + model.eval() + + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + eval_acc_top1_list = [] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + + out = model(img) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + + if batch_id % 100 == 0: + eval_acc_top1_list.append(float(acc_top1.numpy())) + _logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + if batch_num > 0 and batch_id + 1 >= batch_num: + break + + eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list) + + return eval_acc_top1 + + def check_thresholds(self, model): + check_num = 0 + for name, layer in model.named_sublayers(): + layer_name = layer.full_name() + if layer_name in self.gt_thresholds: + ref_val = self.gt_thresholds[layer_name] + assert hasattr(layer, '_quant_config') + + quant_config = layer._quant_config + in_val = quant_config.in_act_quantizer.thresholds + out_val = quant_config.out_act_quantizer.thresholds + wt_val = quant_config.wt_quantizer.thresholds + check_num += 1 + + self.assertTrue( + np.allclose( + ref_val[0], in_val, atol=1e-3), + "%s | The thresholds(%s) is different " + "from the ground truth(%s)." % + (layer_name, str(in_val), str(ref_val[0]))) + self.assertTrue( + np.allclose( + ref_val[1], out_val, atol=1e-3), + "%s | The thresholds(%s) is different " + "from the ground truth(%s)." % + (layer_name, str(out_val), str(ref_val[1]))) + if len(ref_val) > 2 and ref_val[2] != []: + self.assertTrue( + np.allclose( + ref_val[2], wt_val, atol=1e-3), + "%s | The thresholds(%s) is different " + "from the ground truth(%s)." % + (layer_name, str(wt_val), str(ref_val[2]))) + + self.assertTrue(check_num == len(self.gt_thresholds)) + + def test_ptq(self): + start_time = time.time() + + self.set_vars() + + params_path = self.download_model(self.lenet_url, self.lenet_md5, + "lenet") + params_path += "/lenet_pretrained/lenet.pdparams" + + with fluid.dygraph.guard(): + model = ImperativeLenet() + model_state_dict = paddle.load(params_path) + model.set_state_dict(model_state_dict) + + self.ptq.quantize(model, inplace=True) + + acc_top1 = self.model_test(model, self.batch_num, self.batch_size) + print('acc_top1: %s' % acc_top1) + self.assertTrue( + acc_top1 > self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." 
% + (acc_top1, self.eval_acc_top1)) + + self.ptq.convert(model) + + self.check_thresholds(model) + + input_spec = [ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ] + paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec) + print('Quantized model saved in {%s}' % self.save_path) + + end_time = time.time() + print("total time: %ss" % (end_time - start_time)) + + +class TestImperativePTQHist(TestImperativePTQ): + """ + """ + + def set_vars(self): + config = PTQConfig(HistQuantizer(), AbsmaxQuantizer()) + self.ptq = ImperativePTQ(config) + + self.batch_num = 10 + self.batch_size = 10 + self.eval_acc_top1 = 0.99 + + self.gt_thresholds = { + 'conv2d_0': + [[0.99853515625], [0.35732391771364225], [0.10933732241392136]], + 'batch_norm2d_0': [[0.35732391771364225], [0.4291427868761275]], + 're_lu_0': [[0.4291427868761275], [0.2359918110742001]], + 'max_pool2d_0': [[0.2359918110742001], [0.25665526917146053]], + 'linear_0': + [[1.7037603475152991], [14.395224522473026], [0.4373355209827423]], + 'add_0': [[1.7037603475152991, 0.0], [1.7037603475152991]], + } + + +class TestImperativePTQKL(TestImperativePTQ): + """ + """ + + def set_vars(self): + config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer()) + self.ptq = ImperativePTQ(config) + + self.batch_num = 10 + self.batch_size = 10 + self.eval_acc_top1 = 0.99 + + conv2d_1_wt_thresholds = [ + 0.18116560578346252, 0.17079241573810577, 0.1702047884464264, + 0.179476797580719, 0.1454375684261322, 0.22981858253479004 + ] + self.gt_thresholds = { + 'conv2d_0': [[0.99267578125], [0.37695913558696836]], + 'conv2d_1': [[0.19189296757394914], [0.24514256547263358], + [conv2d_1_wt_thresholds]], + 'batch_norm2d_0': [[0.37695913558696836], [0.27462541429440535]], + 're_lu_0': [[0.27462541429440535], [0.19189296757394914]], + 'max_pool2d_0': [[0.19189296757394914], [0.19189296757394914]], + 'linear_0': [[1.2839322163611087], [8.957185942414352]], + 'add_0': [[1.2839322163611087, 0.0], [1.2839322163611087]], + } + + +if __name__ == '__main__': + unittest.main() -- GitLab
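
For reference, a minimal usage sketch of the dygraph PTQ API introduced by this patch, following the flow exercised in test_imperative_ptq.py. `MyModel` and `calib_reader` are placeholder names; the quantization classes and methods are the ones added above, so treat this as an illustrative sketch rather than part of the change itself.

    import paddle
    from paddle.fluid.contrib.slim.quantization import (
        ImperativePTQ, PTQConfig, HistQuantizer, AbsmaxQuantizer)

    # Build the PTQ helper: histogram calibration for activations,
    # per-tensor abs-max for weights (any BaseQuantizer pair works).
    ptq = ImperativePTQ(PTQConfig(HistQuantizer(), AbsmaxQuantizer()))

    model = MyModel()  # a trained paddle.nn.Layer (placeholder)

    # quantize() deep-copies the model unless inplace=True and registers
    # forward-post hooks that sample per-layer input/output statistics.
    quant_model = ptq.quantize(model, inplace=False)

    # Run a few calibration batches so the quantizers can collect data.
    quant_model.eval()
    for images, _ in calib_reader():  # placeholder calibration data source
        quant_model(paddle.to_tensor(images))

    # convert() computes the thresholds and removes the hooks; the model
    # can then be saved for inference.
    ptq.convert(quant_model)
    paddle.jit.save(
        layer=quant_model,
        path="ptq_model",
        input_spec=[
            # shape matches the LeNet/MNIST example used in the unit test
            paddle.static.InputSpec(
                shape=[None, 1, 28, 28], dtype='float32')
        ])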