Commit 58b3480c authored by RachelXu7

Add Observers

Parent a8e2f02b
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mse import MSEObserver
from .emd import EMDObserver
from .avg import AVGObserver
from .hist import HistObserver
from .kl import KLObserver
__all__ = [
"MSEObserver", "EMDObserver", "AVGObserver", "HistObserver", "KLObserver"
]
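
A minimal usage sketch (assuming these classes are exported from the paddleslim.quant.observers package; any of the five exported observers can fill either slot of QuantConfig):

from paddle.quantization import QuantConfig
from paddleslim.quant.observers import MSEObserver

# Observers calibrate quantization ranges while sample data flows through the
# model; here the same observer type is used for activations and weights.
observer = MSEObserver(quant_bits=8)
q_config = QuantConfig(activation=observer, weight=observer)
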
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class AVGObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(AVGObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return AVGObserverLayer
class AVGObserverLayer(UniformObserver):
def __init__(
self,
layer,
quant_bits=8, ):
super(AVGObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._avg_list = []
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._avg_min, self._avg_max = self.cal_min_max(inputs)
self._avg_list.append(self._avg_max)
return inputs
    def cal_min_max(self, inputs):
        abs_value = paddle.abs(inputs.reshape((inputs.shape[0], -1)))
        # Average the per-sample maximum absolute values over the batch.
        avg_max = float(paddle.mean(paddle.max(abs_value, axis=1)))
        return 0, avg_max
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._avg_min, paddle.mean(
paddle.to_tensor(self._avg_list))
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
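
The layer above reduces each calibration batch to its per-sample maximum absolute values, averages them into one scalar per batch, and later averages across batches in cal_thresholds. The same reduction in plain numpy, on made-up data:

import numpy as np

# Hypothetical calibration batches of shape (batch, features).
batches = [np.random.randn(4, 16).astype("float32") for _ in range(8)]

avg_list = []
for x in batches:
    per_sample_max = np.abs(x.reshape(x.shape[0], -1)).max(axis=1)
    avg_list.append(float(per_sample_max.mean()))  # one scalar per batch

threshold = float(np.mean(avg_list))  # what cal_thresholds stores as _max
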
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import paddle
import numpy as np
from .uniform import UniformObserver
class BaseHistObserver(UniformObserver):
"""
Per-tensor abs max quantizer.
"""
def __init__(self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
sign=True,
symmetric=True):
super(BaseHistObserver, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric, )
self._bin_count = bins_count
self._upsample_bin_count = upsample_bins_count
self._hist_min = None
self._hist_max = None
self._hist = None
def _min_max(self, tensor):
"""" """
return float(paddle.min(tensor).numpy()), float(
paddle.max(tensor).numpy())
    def _init_hists(self, inputs):
        """ Initialize the histogram from the first observed tensor. """
        _min, _max = self._min_max(inputs)
        hist = None
        if _max > _min:
            hist, _ = np.histogram(
                inputs.numpy(), range=(_min, _max), bins=self._bin_count)
            hist = hist.astype(np.float32)
        return hist
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
if self._hist_min is None or self._hist_max is None:
self._hist_min, self._hist_max = self._min_max(inputs)
self._hist = self._init_hists(inputs)
else:
new_min, new_max, new_hist = self.update_min_max_and_hist(
inputs,
self._hist_min,
self._hist_max,
self._hist,
self._bin_count,
self._upsample_bin_count, )
self._hist_min, self._hist_max = new_min, new_max
self._hist = new_hist
return inputs
def update_min_max_and_hist(self, tensor, origin_min, origin_max,
origin_hist, bins_count, upsample_bins_count):
_origin_min, _origin_max = origin_min, origin_max
_new_min, _new_max = self._min_max(tensor)
if (_new_max - _new_min) == 0.0:
return _origin_min, _origin_max, origin_hist
elif _origin_max - _origin_min == 0.0:
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
new_hist = new_hist.astype(np.float32)
return _new_min, _new_max, new_hist
elif _new_max <= _origin_max and _new_min >= _origin_min:
new_hist, _ = np.histogram(
tensor.numpy(),
range=(_origin_min, _origin_max),
bins=bins_count)
new_hist = new_hist.astype(np.float32)
new_hist += origin_hist
return _origin_min, _origin_max, new_hist
else:
_new_min = min(_new_min, _origin_min)
_new_max = max(_new_max, _origin_max)
_new_min, _new_max, downsample_bins_count, start_bin_idx = self._relax_min_max(
_new_min, _new_max, _origin_min, _origin_max, bins_count,
upsample_bins_count)
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
merged_histogram = self._merge_histograms(
new_hist, origin_hist, upsample_bins_count,
downsample_bins_count, start_bin_idx, bins_count)
return _new_min, _new_max, merged_histogram
def _merge_histograms(
self,
new_hist: np.ndarray,
origin_hist: np.ndarray,
upsample_bins_count: int,
downsample_bins_count: int,
start_bin_idx: int,
bins_count: int, ):
upsampled_histogram = np.repeat(origin_hist, upsample_bins_count)
expanded_hist = np.zeros(
(bins_count * downsample_bins_count), dtype=np.float32)
expanded_hist[start_bin_idx:bins_count * upsample_bins_count +
start_bin_idx] = upsampled_histogram
cumsumed_hist = np.cumsum(
expanded_hist,
dtype=np.float64)[downsample_bins_count - 1::downsample_bins_count]
shift_cumsumed_hist = np.zeros((bins_count), dtype=np.float64)
shift_cumsumed_hist[1:] = cumsumed_hist[0:-1]
sampled_hist = (
cumsumed_hist - shift_cumsumed_hist) / upsample_bins_count
new_hist += sampled_hist.astype(np.float32)
return new_hist
def _relax_min_max(self, new_min, new_max, origin_min, origin_max,
bins_count, upsample_bins_count):
_bin_width = (origin_max - origin_min) / (
bins_count * upsample_bins_count)
        # Cast to int: this value is used as an array size and a slice step.
        downsample_bins_count = int(
            np.ceil((new_max - new_min) / (bins_count * _bin_width)))
error = downsample_bins_count * bins_count * _bin_width - (
new_max - new_min)
new_max += error
start_bin_idx = round((origin_min - new_min) / _bin_width)
return new_min, new_max, downsample_bins_count, start_bin_idx
@abc.abstractmethod
def cal_min_max(self):
pass
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
assert self._hist is not None
self._min, self._max = self.cal_min_max()
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
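
The cumsum trick in _merge_histograms preserves total mass while resampling the old histogram onto the new, wider range. A toy trace with made-up counts, chosen so the numbers can be checked by hand: 4 bins over [0, 4) grow to cover [0, 8), with upsample_bins_count = 2, which makes downsample_bins_count = 4 and start_bin_idx = 0 per _relax_min_max:

import numpy as np

bins_count, upsample, downsample, start_idx = 4, 2, 4, 0
origin_hist = np.array([1., 2., 3., 4.], dtype=np.float32)

upsampled = np.repeat(origin_hist, upsample)              # 8 fine bins
expanded = np.zeros(bins_count * downsample, np.float32)  # 16 fine bins
expanded[start_idx:start_idx + bins_count * upsample] = upsampled

csum = np.cumsum(expanded, dtype=np.float64)[downsample - 1::downsample]
shifted = np.zeros(bins_count, np.float64)
shifted[1:] = csum[:-1]
resampled = (csum - shifted) / upsample
# resampled == [3., 7., 0., 0.]: the old mass lands in [0, 4), the sum stays 10.
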
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class EMDObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(EMDObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return EMDObserverLayer
class EMDObserverLayer(UniformObserver):
def __init__(self, layer, quant_bits=8):
super(EMDObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._emd_min, self._emd_max = self.cal_min_max(inputs)
return inputs
    def cal_min_max(self, inputs):
        abs_max_value = float(paddle.max(paddle.abs(paddle.flatten(inputs))))
        abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
        best_scale = abs_max_value
        s = 0.3
        while s <= 1.0:
            scale = s * abs_max_value
            s += 0.02
            quant_var = paddle.clip(
                paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
                self.qmax)
            quant_dequant_var = quant_var / self.qmax * scale
            # EMD-style loss: match the mean and standard deviation of the
            # original tensor and its quant-dequantized counterpart.
            emd_loss = paddle.abs(
                paddle.mean(inputs) - paddle.mean(quant_dequant_var)
            ) + paddle.abs(paddle.std(inputs) - paddle.std(quant_dequant_var))
            emd_loss = float(emd_loss)
            if emd_loss <= self._calibration_loss:
                self._calibration_loss = emd_loss
                best_scale = scale
        return 0, best_scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._emd_min, self._emd_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory
class HistObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
percent=0.999,
sign=True,
symmetric=True):
super(HistObserver, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
percent=percent,
sign=sign,
symmetric=symmetric)
def _get_class(self):
return PercentHistObserverLayer
class PercentHistObserverLayer(BaseHistObserver):
"""
Per-tensor abs max quantizer.
"""
def __init__(self,
layer,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64,
percent=0.999,
sign=True,
symmetric=True):
super(PercentHistObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
sign=sign,
symmetric=symmetric)
self._percent = percent
def _cal_min_max_by_percent(self):
hist = self._hist / np.sum(self._hist, dtype=np.float64)
cumsumed_hist = np.cumsum(hist)
max_idx = np.argwhere(cumsumed_hist >= self._percent)[0]
min_idx = np.argwhere(cumsumed_hist >= (1 - self._percent))[0]
bin_width = (self._hist_max - self._hist_min) / hist.shape[0]
_max = self._hist_min + float((max_idx - 0.5) * bin_width)
_min = self._hist_min + float((min_idx - 0.5) * bin_width)
return _min, _max
def cal_min_max(self):
return self._cal_min_max_by_percent()
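
Tracing _cal_min_max_by_percent on made-up counts makes the half-bin offset visible. Six bins over [-3, 3) with a toy percent of 0.95:

import numpy as np

hist = np.array([1., 5., 30., 50., 10., 4.])  # made-up counts, sum = 100
hist_min, hist_max, percent = -3.0, 3.0, 0.95

cdf = np.cumsum(hist / hist.sum())            # [.01 .06 .36 .86 .96 1.]
max_idx = np.argwhere(cdf >= percent)[0]      # 4
min_idx = np.argwhere(cdf >= 1 - percent)[0]  # 1
bin_width = (hist_max - hist_min) / hist.shape[0]     # 1.0
_max = hist_min + float((max_idx - 0.5) * bin_width)  # 0.5
_min = hist_min + float((min_idx - 0.5) * bin_width)  # -2.5
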
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import numpy as np
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory

_logger = logging.getLogger(__name__)
class KLObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(
self,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64, ):
super(KLObserver, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count, )
def _get_class(self):
return KLObserverLayer
class KLObserverLayer(BaseHistObserver):
"""
Per-tensor KL observer.
"""
def __init__(
self,
layer,
quant_bits=8,
bins_count=2048,
upsample_bins_count=64, ):
super(KLObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
upsample_bins_count=upsample_bins_count,
sign=True,
symmetric=True)
def _search_min_max_by_kl(self):
bin_width = (self._hist_max - self._hist_min) / self._bin_count
_max = cal_kl_threshold(self._hist, bin_width, self.bit_length())
return 0., _max
def cal_min_max(self):
return self._search_min_max_by_kl()
def expand_quantized_bins(quantized_bins, reference_bins):
    '''
    Expand the quantized bins back onto the reference binning, spreading each
    quantized bin's mass evenly over its non-empty reference bins.
    '''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
    '''
    Calculate the KL divergence between the reference distribution P and the
    candidate distribution Q, guarding against empty bins.
    '''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
            if q_idx == 0:
                _logger.error("Fatal error: q_idx == 0 at idx = %d, p_idx = %s"
                              % (idx, str(p_idx)))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_threshold(hist, bin_width, bits):
    '''
    Use the KL-divergence method to compute a more precise clipping threshold.
    Args:
        hist(np.ndarray): The histogram of the tensor.
        bin_width(float): The bin width of the histogram.
        bits(int): The quantization bits.
    '''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(
candidate_distr_Q[j_start:j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
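
cal_kl_threshold operates on a plain 1-D histogram, so it can be exercised standalone. A sketch on synthetic activations (assumed roughly Gaussian, binned by absolute value as the KL observer's symmetric histogram would be):

import numpy as np

samples = np.abs(np.random.randn(100000)).astype("float32")
hist, edges = np.histogram(samples, bins=2048, range=(0.0, 6.0))
bin_width = float(edges[1] - edges[0])

threshold = cal_kl_threshold(hist.astype(np.float32), bin_width, 8)
# threshold is the clipping value whose 8-bit quantized distribution has the
# smallest KL divergence from the original histogram.
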
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class MSEObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(MSEObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return MSEObserverLayer
class MSEObserverLayer(UniformObserver):
    def __init__(self, layer, quant_bits=8):
        super(MSEObserverLayer, self).__init__(quant_bits=quant_bits)
        self._quant_bits = quant_bits
        self._calibration_loss = float('inf')
        self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._mse_min, self._mse_max = self.cal_min_max(inputs)
return inputs
    def cal_min_max(self, inputs):
        abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
        abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
        best_scale = abs_max_value
        s = 0.3
        while s <= 1.0:
            scale = s * abs_max_value
            s += 0.02
            quant_var = paddle.clip(
                paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
                self.qmax)
            quant_dequant_var = quant_var / self.qmax * scale
            mse_loss = float(((inputs - quant_dequant_var)**2).mean())
            if mse_loss <= self._calibration_loss:
                self._calibration_loss = mse_loss
                best_scale = scale
        return 0, best_scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._mse_min, self._mse_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
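
Stripped of the observer plumbing, the search in cal_min_max above is a grid search over fractions of the tensor's abs-max, scoring each candidate by the quant-dequant round-trip error. The same loop in numpy on toy data:

import numpy as np

x = np.random.randn(1024).astype("float32")  # toy activations
qmax = 127                                   # signed 8-bit
abs_max = max(float(np.abs(x).max()), 1e-8)

best_loss, best_scale = float("inf"), abs_max
for s in np.arange(0.3, 1.01, 0.02):
    scale = float(s) * abs_max
    qdq = np.clip(np.round(x / scale * qmax), -qmax - 1, qmax) / qmax * scale
    loss = float(((x - qdq) ** 2).mean())
    if loss < best_loss:
        best_loss, best_scale = loss, scale
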
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.quantization.base_observer import BaseObserver
class UniformObserver(BaseObserver):
""" An abstract class used for uniform quantization.
"""
def __init__(
self,
quant_bits=8,
sign=True,
symmetric=True, ):
super(UniformObserver, self).__init__()
self._quant_bits = quant_bits
self._sign = sign
self._symmetric = symmetric
self._min = None
self._max = None
self._qmin = None
self._qmax = None
self._scale = None
self._zero_point = None
    @property
    def qmin_qmax(self):
        """ Get the range of the quantized integer. """
        if self._qmin is not None and self._qmax is not None:
            return self._qmin, self._qmax
        if self._sign:
            self._qmin = -2**(self.bit_length() - 1)
            self._qmax = 2**(self.bit_length() - 1) - 1
        else:
            self._qmin = 0
            self._qmax = 2**self.bit_length() - 1
        return self._qmin, self._qmax
def cal_scales_zero_points(self):
""" Compute the scales and zero_points.
"""
assert self._min is not None and self._max is not None
_qmin, _qmax = self.qmin_qmax
        # For one-sided distributions, the range (_min, _max) is relaxed to
        # include zero so that common operations such as zero padding do not
        # introduce extra quantization error.
_min = min(self._min, 0.)
_max = max(self._max, 0.)
if self._symmetric:
self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
if self._sign:
self._zero_point = 0
else:
self._zero_point = (_qmax + _qmin) / 2
else:
self._scale = (_max - _min) / float(_qmax - _qmin)
self._zero_point = _qmin - round(_min / self._scale)
self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
return self._scale, self._zero_point
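
Worked numbers for cal_scales_zero_points, assuming signed 8-bit and a hypothetical calibrated range:

# qmin, qmax = -128, 127; calibrated range (_min, _max) = (-0.5, 2.0).
# Symmetric:  scale = max(0.5, 2.0) / ((127 - (-128)) / 2) = 2.0 / 127.5 ≈ 0.0157
#             zero_point = 0  (signed symmetric quantization is zero-centered)
# Asymmetric: scale = (2.0 - (-0.5)) / (127 - (-128)) = 2.5 / 255 ≈ 0.0098
#             zero_point = -128 - round(-0.5 / 0.0098) = -128 + 51 = -77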