From 5000cf3dae11dfa84369732c8f043eeabe5dbbb4 Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 28 Mar 2023 16:09:03 +0800 Subject: [PATCH] Add channel-wise weights observers (#1697) --- paddleslim/quant/observers/__init__.py | 11 +- paddleslim/quant/observers/abs_max_weight.py | 112 +++++++++++++++++++ paddleslim/quant/observers/channel_wise.py | 47 ++++++++ paddleslim/quant/observers/mse_weight.py | 57 ++++++++++ tests/quantization/test_observers_acc.py | 52 +++++---- 5 files changed, 259 insertions(+), 20 deletions(-) create mode 100644 paddleslim/quant/observers/abs_max_weight.py create mode 100644 paddleslim/quant/observers/channel_wise.py create mode 100644 paddleslim/quant/observers/mse_weight.py diff --git a/paddleslim/quant/observers/__init__.py b/paddleslim/quant/observers/__init__.py index 614712f3..c2f4e4c8 100644 --- a/paddleslim/quant/observers/__init__.py +++ b/paddleslim/quant/observers/__init__.py @@ -17,7 +17,16 @@ from .kl import KLObserver from .mse import MSEObserver from .emd import EMDObserver from .avg import AVGObserver +from .mse_weight import MSEChannelWiseWeightObserver +from .abs_max_weight import AbsMaxChannelWiseWeightObserver __all__ = [ - "HistObserver", "KLObserver", "MSEObserver", "EMDObserver", "AVGObserver" + "HistObserver", + "KLObserver", + "MSEObserver", + "EMDObserver", + "AVGObserver", + "MSEWeightObserver", + "MSEChannelWiseWeightObserver", + "AbsMaxChannelWiseWeightObserver", ] diff --git a/paddleslim/quant/observers/abs_max_weight.py b/paddleslim/quant/observers/abs_max_weight.py new file mode 100644 index 00000000..5594ab1c --- /dev/null +++ b/paddleslim/quant/observers/abs_max_weight.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .channel_wise import ChannelWiseObserver +from paddle.quantization.factory import ObserverFactory + + +class AbsMaxChannelWiseWeightObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxChannelWiseWeightObserver + quanter = AbsMaxChannelWiseWeightObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AbsMaxChannelWiseWeightObserver, self).__init__( + quant_bits=quant_bits) + + def _get_class(self): + return AbsMaxChannelWiseWeightObserverLayer + + +class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver): + def __init__(self, layer, quant_bits=8): + super(AbsMaxChannelWiseWeightObserverLayer, self).__init__( + layer, + quant_bits=quant_bits, + sign=True, + symmetric=True, ) + self.quant_bits = quant_bits + self.calibration_loss = float('inf') + self.qmin, self.qmax = self.qmin_qmax + self._max = None + self._scale = None + self._zero_point = None + + def forward(self, inputs): + if self._max is None: + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + reduce_axis = tuple( + [i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis) + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), + np.float32(1e-8), abs_max_values) + minimum_loss = paddle.full(abs_max_values.shape, float('inf')) + result = abs_max_values + factor = 0.3 + while factor <= 1.0: + scales = factor * abs_max_values + factor += 0.02 + expand_scales = paddle.unsqueeze(scales, axis=reduce_axis) + quant_var = paddle.clip( + paddle.round(inputs / expand_scales * self.qmax), self.qmin, + self.qmax) + quant_dequant_var = quant_var / self.qmax * expand_scales + + mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis) + result = paddle.where(mse_loss < minimum_loss, scales, result) + minimum_loss = paddle.minimum(mse_loss, minimum_loss) + + return result + + def min_value(self) -> float: + return 0. + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """ Compute thresholds for MAX function. + """ + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """ Return output scales. + """ + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """ Return output zero points. + """ + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/paddleslim/quant/observers/channel_wise.py b/paddleslim/quant/observers/channel_wise.py new file mode 100644 index 00000000..7d270cdd --- /dev/null +++ b/paddleslim/quant/observers/channel_wise.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +import numpy as np +import paddle +from .mse import MSEObserverLayer +from .uniform import UniformObserver +from paddle.quantization.factory import ObserverFactory + +CHANNEL_AXIS: Dict[type, int] = {paddle.nn.Conv2D: 0, paddle.nn.Linear: 1} + + +class ChannelWiseObserver(UniformObserver): + def __init__( + self, + layer, + quant_bits=8, + sign=True, + symmetric=True, ): + super(ChannelWiseObserver, self).__init__( + quant_bits=quant_bits, + sign=sign, + symmetric=symmetric, ) + self._channel_axis = CHANNEL_AXIS[type(layer)] + self._quant_bits = quant_bits + + def quant_axis(self): + """ Return quantization axis. + """ + return self._channel_axis + + def bit_length(self): + """ Return the bit length of quantized data. + """ + return self._quant_bits diff --git a/paddleslim/quant/observers/mse_weight.py b/paddleslim/quant/observers/mse_weight.py new file mode 100644 index 00000000..16aa2726 --- /dev/null +++ b/paddleslim/quant/observers/mse_weight.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle.quantization.factory import ObserverFactory +from .abs_max_weight import AbsMaxChannelWiseWeightObserverLayer + + +class MSEChannelWiseWeightObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values and calculates the quantization scales by minimizing + the quantization MSE error. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import MSEChannelWiseWeightObserver + quanter = MSEChannelWiseWeightObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(MSEChannelWiseWeightObserver, self).__init__( + quant_bits=quant_bits) + + def _get_class(self): + return MSEChannelWiseWeightObserverLayer + + +class MSEChannelWiseWeightObserverLayer(AbsMaxChannelWiseWeightObserverLayer): + def __init__(self, layer, quant_bits=8): + super(MSEChannelWiseWeightObserverLayer, self).__init__( + layer, quant_bits=quant_bits) + + def _cal_abs_max(self, inputs): + reduce_axis = tuple( + [i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis) + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), + np.float32(1e-8), abs_max_values) + return abs_max_values diff --git a/tests/quantization/test_observers_acc.py b/tests/quantization/test_observers_acc.py index a2b03685..d026dd53 100644 --- a/tests/quantization/test_observers_acc.py +++ b/tests/quantization/test_observers_acc.py @@ -31,6 +31,8 @@ from paddleslim.quant.observers.mse import MSEObserverLayer from paddleslim.quant.observers.avg import AVGObserverLayer from paddleslim.quant.observers.emd import EMDObserverLayer from paddleslim.quant.observers.kl import KLObserverLayer +from paddleslim.quant.observers.mse_weight import MSEChannelWiseWeightObserver +from paddleslim.quant.observers.abs_max_weight import AbsMaxChannelWiseWeightObserver from paddle.nn.quant.format import LinearDequanter, LinearQuanter import logging @@ -70,10 +72,14 @@ class ImperativeLenet(paddle.nn.Layer): class TestPTQObserverAcc(unittest.TestCase): - def __init__(self, observer, observer_type, *args, **kvargs): + def __init__(self, + activation_observer, + weight_observer=None, + *args, + **kvargs): super(TestPTQObserverAcc, self).__init__(*args, **kvargs) - self.observer = observer - self.observer_type = observer_type + self.act_observer = activation_observer + self.weight_observer = weight_observer def setUp(self): paddle.set_device("cpu") @@ -100,7 +106,9 @@ class TestPTQObserverAcc(unittest.TestCase): def init_case(self): self.q_config = QuantConfig(activation=None, weight=None) self.q_config.add_type_config( - paddle.nn.Conv2D, activation=self.observer, weight=self.observer) + paddle.nn.Conv2D, + activation=self.act_observer, + weight=self.weight_observer) def _count_layers(self, model, layer_type): count = 0 @@ -202,7 +210,7 @@ class TestPTQObserverAcc(unittest.TestCase): quant_model = ptq.quantize(model, inplace=False) ptq_sample(quant_model) - converted_model = ptq.convert(quant_model, inplace=False) + converted_model = ptq.convert(quant_model, inplace=True) top1_2, top5_2 = test(converted_model) _logger.info( @@ -220,20 +228,26 @@ class TestPTQObserverAcc(unittest.TestCase): observer_suite = unittest.TestSuite() -observer_suite.addTest( - TestPTQObserverAcc( - observer=HistObserver(sign=True, symmetric=True), - observer_type=PercentHistObserverLayer)) -observer_suite.addTest( - TestPTQObserverAcc( - observer=KLObserver(bins_count=256), observer_type=KLObserverLayer)) - -observer_suite.addTest( - TestPTQObserverAcc(observer=AVGObserver(), observer_type=AVGObserverLayer)) -observer_suite.addTest( - TestPTQObserverAcc(observer=EMDObserver(), observer_type=EMDObserverLayer)) -observer_suite.addTest( - TestPTQObserverAcc(observer=MSEObserver(), observer_type=MSEObserverLayer)) + +for _observer in [ + AVGObserver(), + EMDObserver(), + MSEObserver(), + KLObserver(bins_count=256), + HistObserver(sign=True, symmetric=True), +]: + observer_suite.addTest( + TestPTQObserverAcc( + activation_observer=_observer, weight_observer=_observer)) + +for _weight_observer in [ + MSEChannelWiseWeightObserver(), + AbsMaxChannelWiseWeightObserver(), +]: + observer_suite.addTest( + TestPTQObserverAcc( + activation_observer=MSEObserver(), + weight_observer=_weight_observer)) if __name__ == '__main__': runner = unittest.TextTestRunner(verbosity=2) -- GitLab