From 069cfb7d24e9ae003478d9b5a777811e470eb278 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 17 Aug 2023 15:19:08 +0800 Subject: [PATCH] add abs_max observer (#1786) --- paddleslim/quant/observers/__init__.py | 2 + paddleslim/quant/observers/abs_max.py | 102 +++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 paddleslim/quant/observers/abs_max.py diff --git a/paddleslim/quant/observers/__init__.py b/paddleslim/quant/observers/__init__.py index c2f4e4c8..7ab3b723 100644 --- a/paddleslim/quant/observers/__init__.py +++ b/paddleslim/quant/observers/__init__.py @@ -17,6 +17,7 @@ from .kl import KLObserver from .mse import MSEObserver from .emd import EMDObserver from .avg import AVGObserver +from .abs_max import AbsmaxObserver from .mse_weight import MSEChannelWiseWeightObserver from .abs_max_weight import AbsMaxChannelWiseWeightObserver @@ -27,6 +28,7 @@ __all__ = [ "EMDObserver", "AVGObserver", "MSEWeightObserver", + "AbsmaxObserver", "MSEChannelWiseWeightObserver", "AbsMaxChannelWiseWeightObserver", ] diff --git a/paddleslim/quant/observers/abs_max.py b/paddleslim/quant/observers/abs_max.py new file mode 100644 index 00000000..10cdf298 --- /dev/null +++ b/paddleslim/quant/observers/abs_max.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .uniform import UniformObserver +from paddle.quantization.factory import ObserverFactory + + +class AbsmaxObserver(ObserverFactory): + r""" + It collects maximum absolute values of target tensor. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AbsmaxObserver, self).__init__(quant_bits=quant_bits) + + def _get_class(self): + return AbsmaxObserverLayer + + +class AbsmaxObserverLayer(UniformObserver): + def __init__( + self, + layer, + quant_bits=8, ): + super(AbsmaxObserverLayer, self).__init__(quant_bits=quant_bits) + self._quant_bits = quant_bits + self._layer = layer + self._scale = None + self._zero_point = None + self._min = None + self._max = paddle.to_tensor(1e-7, dtype="float32") + self.step = 0 + + def forward(self, inputs): + """ Calculate forward pass. + """ + self._min, self._max = self.cal_min_max(inputs) + return inputs + + def cal_min_max(self, inputs): + abs_max_val = paddle.max(paddle.abs(inputs.cast("float32"))) + abs_max_val = paddle.maximum(abs_max_val, self._max) + return 0, abs_max_val + + def cal_thresholds(self): + """ Compute thresholds for MAX function. + """ + self._scale, self._zero_point = self.cal_scales_zero_points() + + def min_value(self) -> float: + return self._min + + def max_value(self) -> float: + return self._max + + def bit_length(self): + """ Return the bit length of quantized data. + """ + return self._quant_bits + + def quant_axis(self): + """ Return quantization axis. + """ + return -1 + + def scales(self): + """ Return output scales. + """ + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """ Return output zero points. + """ + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point -- GitLab