Commit a8e2f02b authored by RachelXu7

Add Observers

Parent b692d8ec
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mse import MSEObserver
from .emd import EMDObserver
from .avg import AVGObserver
from .hist import HistObserver
from .kl import KLObserver
__all__ = [
"MSEObserver", "EMDObserver", "AVGObserver", "HistObserver", "KLObserver"
]
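# Usage sketch (hypothetical snippet mirroring the unit test at the end of
# this commit): plug one of these observers into a QuantConfig and run
# post-training quantization on a model.
#
#     import paddle
#     from paddle.quantization import QuantConfig, PTQ
#     from paddle.vision.models import resnet18
#     from paddleslim.quant.observers import MSEObserver
#
#     observer = MSEObserver(quant_bits=8)
#     q_config = QuantConfig(activation=None, weight=None)
#     q_config.add_type_config(
#         paddle.nn.Conv2D, activation=observer, weight=observer)
#     model = resnet18()
#     model.eval()
#     quant_model = PTQ(q_config).quantize(model, inplace=False)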
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class AVGObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(AVGObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return AVGObserverLayer
class AVGObserverLayer(UniformObserver):
def __init__(
self,
layer,
quant_bits=8, ):
super(AVGObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._avg_list = []
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._avg_min, self._avg_max = self.cal_min_max(inputs)
self._avg_list.append(self._avg_max)
return inputs
def cal_min_max(self, inputs):
abs_value = paddle.abs(inputs.reshape((inputs.shape[0], -1)))
avg_abs_max = float(paddle.mean(paddle.max(abs_value, axis=1)))
return 0, avg_abs_max
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._avg_min, paddle.mean(
paddle.to_tensor(self._avg_list))
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
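# A minimal NumPy sketch of the reduction performed by AVGObserverLayer
# (illustrative helper, not part of the library): each forward pass records
# the batch mean of per-sample abs-max values; the final threshold is the
# mean of the recorded values.
import numpy as np
def _avg_abs_max(batches):
    per_batch = [
        float(np.abs(b.reshape(b.shape[0], -1)).max(axis=1).mean())
        for b in batches
    ]
    return float(np.mean(per_batch))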
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import paddle
import numpy as np
from .uniform import UniformObserver
class BaseHistObserver(UniformObserver):
"""
Per-tensor abs max quantizer.
"""
def __init__(self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
sign=True,
symmetric=True):
super(BaseHistObserver, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric, )
self._bin_count = bins_count
self._upsample_bin_count = upsample_bins_count
self._hist_min = None
self._hist_max = None
self._hist = None
def _min_max(self, tensor):
"""" """
return float(paddle.min(tensor).numpy()), float(
paddle.max(tensor).numpy())
def _init_hists(self, inputs):
_min, _max = self._min_max(inputs)
hist = None
if _max > _min:
hist, _ = np.histogram(
inputs.numpy(), range=(_min, _max), bins=self._bin_count)
hist = hist.astype(np.float32)
return hist
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
if self._hist_min is None or self._hist_max is None:
self._hist_min, self._hist_max = self._min_max(inputs)
self._hist = self._init_hists(inputs)
else:
new_min, new_max, new_hist = self.update_min_max_and_hist(
inputs,
self._hist_min,
self._hist_max,
self._hist,
self._bin_count,
self._upsample_bin_count, )
self._hist_min, self._hist_max = new_min, new_max
self._hist = new_hist
return inputs
def update_min_max_and_hist(self, tensor, origin_min, origin_max,
origin_hist, bins_count, upsample_bins_count):
_origin_min, _origin_max = origin_min, origin_max
_new_min, _new_max = self._min_max(tensor)
if (_new_max - _new_min) == 0.0:
return _origin_min, _origin_max, origin_hist
elif _origin_max - _origin_min == 0.0:
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
new_hist = new_hist.astype(np.float32)
return _new_min, _new_max, new_hist
elif _new_max <= _origin_max and _new_min >= _origin_min:
new_hist, _ = np.histogram(
tensor.numpy(),
range=(_origin_min, _origin_max),
bins=bins_count)
new_hist = new_hist.astype(np.float32)
new_hist += origin_hist
return _origin_min, _origin_max, new_hist
else:
_new_min = min(_new_min, _origin_min)
_new_max = max(_new_max, _origin_max)
_new_min, _new_max, downsample_bins_count, start_bin_idx = self._relax_min_max(
_new_min, _new_max, _origin_min, _origin_max, bins_count,
upsample_bins_count)
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
merged_histogram = self._merge_histograms(
new_hist, origin_hist, upsample_bins_count,
downsample_bins_count, start_bin_idx, bins_count)
return _new_min, _new_max, merged_histogram
def _merge_histograms(
self,
new_hist: np.ndarray,
origin_hist: np.ndarray,
upsample_bins_count: int,
downsample_bins_count: int,
start_bin_idx: int,
bins_count: int, ):
upsampled_histogram = np.repeat(origin_hist, upsample_bins_count)
expanded_hist = np.zeros(
(bins_count * downsample_bins_count), dtype=np.float32)
expanded_hist[start_bin_idx:bins_count * upsample_bins_count +
start_bin_idx] = upsampled_histogram
cumsumed_hist = np.cumsum(
expanded_hist,
dtype=np.float64)[downsample_bins_count - 1::downsample_bins_count]
shift_cumsumed_hist = np.zeros((bins_count), dtype=np.float64)
shift_cumsumed_hist[1:] = cumsumed_hist[0:-1]
sampled_hist = (
cumsumed_hist - shift_cumsumed_hist) / upsample_bins_count
new_hist += sampled_hist.astype(np.float32)
return new_hist
def _relax_min_max(self, new_min, new_max, origin_min, origin_max,
bins_count, upsample_bins_count):
_bin_width = (origin_max - origin_min) / (
bins_count * upsample_bins_count)
downsample_bins_count = np.ceil(
(new_max - new_min) / (bins_count * _bin_width))
error = downsample_bins_count * bins_count * _bin_width - (
new_max - new_min)
new_max += error
start_bin_idx = round((origin_min - new_min) / _bin_width)
return new_min, new_max, downsample_bins_count, start_bin_idx
@abc.abstractmethod
def cal_min_max(self):
pass
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
assert self._hist is not None
self._min, self._max = self.cal_min_max()
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
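# Sketch of the rebinning trick behind _merge_histograms (hypothetical
# helper): upsample the old histogram by repetition, drop it into the
# relaxed range, then collapse back to `bins` bins with a strided
# cumulative sum so that total counts are preserved.
import numpy as np
def _rebin(hist, up, down, start, bins):
    expanded = np.zeros(bins * down, dtype=np.float64)
    expanded[start:start + hist.size * up] = np.repeat(hist, up)
    csum = np.cumsum(expanded)[down - 1::down]
    shifted = np.concatenate(([0.0], csum[:-1]))
    return (csum - shifted) / up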
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class EMDObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(EMDObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return EMDObserverLayer
class EMDObserverLayer(UniformObserver):
def __init__(self, layer, quant_bits=8):
super(EMDObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._emd_min, self._emd_max = self.cal_min_max(inputs)
return inputs
def cal_min_max(self, inputs):
abs_max_value = float(paddle.max(paddle.abs(paddle.flatten(inputs))))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
quant_var = paddle.clip(
paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
self.qmax)
quant_dequant_var = quant_var / self.qmax * scale
emd_loss = paddle.abs(
paddle.mean(inputs) - paddle.mean(quant_dequant_var)
) + paddle.abs(paddle.std(inputs) - paddle.std(quant_dequant_var))
emd_loss = float(emd_loss)
if emd_loss <= self._calibration_loss:
self._calibration_loss = emd_loss
return 0, scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._emd_min, self._emd_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
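# NumPy sketch of the scale search above (illustrative; assumes signed
# 8-bit, qmax = 127): sweep candidate scales from 0.3 to 1.0 times the
# abs-max and keep the one whose quant-dequant output best matches the
# input's mean and standard deviation.
import numpy as np
def _emd_scale(x, qmax=127):
    best_scale, best_loss = None, float('inf')
    for s in np.arange(0.3, 1.0001, 0.02):
        scale = s * max(float(np.abs(x).max()), 1e-8)
        qdq = np.clip(np.round(x / scale * qmax), -qmax - 1, qmax) / qmax * scale
        loss = abs(x.mean() - qdq.mean()) + abs(x.std() - qdq.std())
        if loss <= best_loss:
            best_loss, best_scale = loss, scale
    return best_scale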
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory
class HistObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
percent=0.999,
sign=True,
symmetric=True):
super(HistObserver, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
percent=percent,
sign=sign,
symmetric=symmetric)
def _get_class(self):
return PercentHistObserverLayer
class PercentHistObserverLayer(BaseHistObserver):
"""
Per-tensor abs max quantizer.
"""
def __init__(self,
layer,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
percent=0.999,
sign=True,
symmetric=True):
super(PercentHistObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
sign=sign,
symmetric=symmetric)
self._percent = percent
def _cal_min_max_by_percent(self):
hist = self._hist / np.sum(self._hist, dtype=np.float64)
cumsumed_hist = np.cumsum(hist)
max_idx = np.argwhere(cumsumed_hist >= self._percent)[0]
min_idx = np.argwhere(cumsumed_hist >= (1 - self._percent))[0]
bin_width = (self._hist_max - self._hist_min) / hist.shape[0]
_max = self._hist_min + float((max_idx - 0.5) * bin_width)
_min = self._hist_min + float((min_idx - 0.5) * bin_width)
return _min, _max
def cal_min_max(self):
return self._cal_min_max_by_percent()
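# NumPy sketch of the percentile cut used by PercentHistObserverLayer
# (illustrative helper): normalize the histogram, take its cumulative sum,
# and read off the bins where the mass first crosses (1 - percent) and
# percent.
import numpy as np
def _percent_min_max(hist, hist_min, hist_max, percent=0.999):
    csum = np.cumsum(hist / hist.sum(dtype=np.float64))
    bin_width = (hist_max - hist_min) / hist.size
    min_idx = int(np.argmax(csum >= 1 - percent))
    max_idx = int(np.argmax(csum >= percent))
    return (hist_min + (min_idx - 0.5) * bin_width,
            hist_min + (max_idx - 0.5) * bin_width)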
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import numpy as np
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory
_logger = logging.getLogger(__name__)
class KLObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(
self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64, ):
super(KLObserver, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count, )
def _get_class(self):
return KLObserverLayer
class KLObserverLayer(BaseHistObserver):
"""
Per-tensor KL observer.
"""
def __init__(
self,
layer,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64, ):
super(KLObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
sign=True,
symmetric=True)
def _search_min_max_by_kl(self):
bin_width = (self._hist_max - self._hist_min) / self._bin_count
_max = cal_kl_threshold(self._hist, bin_width, self.bit_length())
return 0., _max
def cal_min_max(self):
return self._search_min_max_by_kl()
def expand_quantized_bins(quantized_bins, reference_bins):
'''
Expand hist bins.
'''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_threshold(hist, bin_width, bits):
'''
Use the KL-divergence method to compute a more precise threshold.
Args:
hist(np.ndarray): The histogram of the tensor.
bin_width(float): The bin width of the histogram.
bits(int): The quantization bits.
'''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(
candidate_distr_Q[j_start:j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
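# Quick sanity check of cal_kl_threshold on synthetic data (illustrative,
# not part of the library): histogram roughly half-normal activations and
# ask for an 8-bit threshold.
if __name__ == '__main__':
    _data = np.abs(np.random.randn(100000))
    _hist, _edges = np.histogram(_data, bins=2048)
    print(cal_kl_threshold(_hist.astype(np.float32), float(_edges[1] - _edges[0]), 8))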
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class MSEObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(MSEObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return MSEObserverLayer
class MSEObserverLayer(UniformObserver):
def __init__(self, layer, quant_bits=8):
super(MSEObserverLayer, self).__init__(quant_bits=quant_bits)
self.quant_bits = quant_bits
self.calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._mse_min, self._mse_max = self.cal_min_max(inputs)
return inputs
def cal_min_max(self, inputs):
abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
quant_var = paddle.clip(
paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
self.qmax)
quant_dequant_var = quant_var / self.qmax * scale
mse_loss = float(((inputs - quant_dequant_var)**2).mean())
if mse_loss <= self.calibration_loss:
self.calibration_loss = mse_loss
return 0, scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._mse_min, self._mse_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
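# The quant-dequant round trip at the heart of the MSE search, as a small
# standalone helper (illustrative; qmax = 2**(bits - 1) - 1 for signed
# symmetric quantization):
import numpy as np
def _quant_dequant(x, scale, qmax=127):
    return np.clip(np.round(x / scale * qmax), -qmax - 1, qmax) / qmax * scale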
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.quantization.base_observer import BaseObserver
class UniformObserver(BaseObserver):
""" An abstract class used for uniform quantization.
"""
def __init__(
self,
quant_bits=8,
sign=True,
symmetric=True, ):
super(UniformObserver, self).__init__()
self._quant_bits = quant_bits
self._sign = sign
self._symmetric = symmetric
self._min = None
self._max = None
self._qmin = None
self._qmax = None
self._scale = None
self._zero_point = None
@property
def qmin_qmax(self):
""" Get the range of the integer."""
if self._qmin is not None and self._qmax is not None:
return self.qmin, self.qmax
if self._sign:
self.qmin = -2**(self.bit_length() - 1)
self.qmax = 2**(self.bit_length() - 1) - 1
else:
self.qmin = 0
self.qmax = 2**self.bit_length()
return self.qmin, self.qmax
def cal_scales_zero_points(self):
""" Compute the scales and zero_points.
"""
assert self._min is not None and self._max is not None
_qmin, _qmax = self.qmin_qmax
# For one-sided distributions, the range (_min, _max) is relaxed to include zero
# so that common operations like zero padding do not introduce quantization error.
_min = min(self._min, 0.)
_max = max(self._max, 0.)
if self._symmetric:
self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
if self._sign:
self._zero_point = 0
else:
self._zero_point = (_qmax + _qmin) / 2
else:
self._scale = (_max - _min) / float(_qmax - _qmin)
self._zero_point = _qmin - round(_min / self._scale)
self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
return self._scale, self._zero_point
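# Worked example of cal_scales_zero_points (signed symmetric int8,
# illustrative numbers): with min = -0.5 and max = 2.0, qmin = -128 and
# qmax = 127, so scale = max(0.5, 2.0) / ((127 - (-128)) / 2) = 2.0 / 127.5
# and zero_point = 0. The asymmetric branch instead uses
# scale = (max - min) / (qmax - qmin) and
# zero_point = clip(qmin - round(min / scale), qmin, qmax).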
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import unittest
import paddle
import tempfile
from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver
from paddleslim.quant.observers.hist import PercentHistObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer
from paddleslim.quant.observers.mse import MSEObserverLayer
from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.emd import EMDObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter
class TestPTQWithObservers(unittest.TestCase):
def __init__(self, observer, observer_type, *args, **kwargs):
super(TestPTQWithObservers, self).__init__(*args, **kwargs)
self.observer = observer
self.observer_type = observer_type
def setUp(self):
paddle.set_device("cpu")
self.init_case()
self.dummy_input = paddle.rand([1, 3, 224, 224])
self.temp_dir = tempfile.TemporaryDirectory(dir="./")
self.path = os.path.join(self.temp_dir.name, 'qat')
def tearDown(self):
self.temp_dir.cleanup()
def runTest(self):
self.test_quantize()
self.test_convert()
def init_case(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config(
paddle.nn.Conv2D, activation=self.observer, weight=self.observer)
def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count
def test_quantize(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
ptq = PTQ(self.q_config)
model.eval()
quant_model = ptq.quantize(model, inplace=False)
out = quant_model(self.dummy_input)
quantizer_cnt = self._count_layers(quant_model, self.observer_type)
self.assertEqual(quantizer_cnt, 2 * conv_count)
def test_convert(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
ptq = PTQ(self.q_config)
model.eval()
quant_model = ptq.quantize(model, inplace=False)
out = quant_model(self.dummy_input)
converted_model = ptq.convert(quant_model, inplace=False)
# check count of LinearQuanter and LinearDequanter in dygraph
quantizer_count_in_dygraph = self._count_layers(converted_model,
LinearQuanter)
dequantizer_count_in_dygraph = self._count_layers(
converted_model, LinearDequanter)
self.assertEqual(quantizer_count_in_dygraph, conv_count)
self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)
observer_suite = unittest.TestSuite()
observer_suite.addTest(
TestPTQWithObservers(
observer=HistObserver(), observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQWithObservers(
observer=KLObserver(bins_count=256, upsample_bins_count=32),
observer_type=KLObserverLayer))
observer_suite.addTest(
TestPTQWithObservers(
observer=EMDObserver(), observer_type=EMDObserverLayer))
observer_suite.addTest(
TestPTQWithObservers(
observer=MSEObserver(), observer_type=MSEObserverLayer))
observer_suite.addTest(
TestPTQWithObservers(
observer=AVGObserver(), observer_type=AVGObserverLayer))
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(observer_suite)