未验证 提交 5000cf3d 编写于 作者: W whs 提交者: GitHub

Add channel-wise weights observers (#1697)

上级 18dd3aad
...@@ -17,7 +17,16 @@ from .kl import KLObserver ...@@ -17,7 +17,16 @@ from .kl import KLObserver
from .mse import MSEObserver from .mse import MSEObserver
from .emd import EMDObserver from .emd import EMDObserver
from .avg import AVGObserver from .avg import AVGObserver
from .mse_weight import MSEChannelWiseWeightObserver
from .abs_max_weight import AbsMaxChannelWiseWeightObserver
__all__ = [ __all__ = [
"HistObserver", "KLObserver", "MSEObserver", "EMDObserver", "AVGObserver" "HistObserver",
"KLObserver",
"MSEObserver",
"EMDObserver",
"AVGObserver",
"MSEWeightObserver",
"MSEChannelWiseWeightObserver",
"AbsMaxChannelWiseWeightObserver",
] ]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .channel_wise import ChannelWiseObserver
from paddle.quantization.factory import ObserverFactory
class AbsMaxChannelWiseWeightObserver(ObserverFactory):
r"""
It collects channel-wise maximum absolute values of target weights.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import AbsMaxChannelWiseWeightObserver
quanter = AbsMaxChannelWiseWeightObserver()
q_config = QuantConfig(activation=None, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(AbsMaxChannelWiseWeightObserver, self).__init__(
quant_bits=quant_bits)
def _get_class(self):
return AbsMaxChannelWiseWeightObserverLayer
class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
def __init__(self, layer, quant_bits=8):
super(AbsMaxChannelWiseWeightObserverLayer, self).__init__(
layer,
quant_bits=quant_bits,
sign=True,
symmetric=True, )
self.quant_bits = quant_bits
self.calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
self._max = None
self._scale = None
self._zero_point = None
def forward(self, inputs):
if self._max is None:
self._max = self._cal_abs_max(inputs)
return inputs
def _cal_abs_max(self, inputs):
reduce_axis = tuple(
[i for i in range(len(inputs.shape)) if i != self.quant_axis()])
abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
np.float32(1e-8), abs_max_values)
minimum_loss = paddle.full(abs_max_values.shape, float('inf'))
result = abs_max_values
factor = 0.3
while factor <= 1.0:
scales = factor * abs_max_values
factor += 0.02
expand_scales = paddle.unsqueeze(scales, axis=reduce_axis)
quant_var = paddle.clip(
paddle.round(inputs / expand_scales * self.qmax), self.qmin,
self.qmax)
quant_dequant_var = quant_var / self.qmax * expand_scales
mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis)
result = paddle.where(mse_loss < minimum_loss, scales, result)
minimum_loss = paddle.minimum(mse_loss, minimum_loss)
return result
def min_value(self) -> float:
return 0.
def max_value(self) -> float:
return self._max
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import numpy as np
import paddle
from .mse import MSEObserverLayer
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
CHANNEL_AXIS: Dict[type, int] = {paddle.nn.Conv2D: 0, paddle.nn.Linear: 1}
class ChannelWiseObserver(UniformObserver):
def __init__(
self,
layer,
quant_bits=8,
sign=True,
symmetric=True, ):
super(ChannelWiseObserver, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric, )
self._channel_axis = CHANNEL_AXIS[type(layer)]
self._quant_bits = quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return self._channel_axis
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.quantization.factory import ObserverFactory
from .abs_max_weight import AbsMaxChannelWiseWeightObserverLayer
class MSEChannelWiseWeightObserver(ObserverFactory):
r"""
It collects channel-wise maximum absolute values and calculates the quantization scales by minimizing
the quantization MSE error.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import MSEChannelWiseWeightObserver
quanter = MSEChannelWiseWeightObserver()
q_config = QuantConfig(activation=None, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(MSEChannelWiseWeightObserver, self).__init__(
quant_bits=quant_bits)
def _get_class(self):
return MSEChannelWiseWeightObserverLayer
class MSEChannelWiseWeightObserverLayer(AbsMaxChannelWiseWeightObserverLayer):
def __init__(self, layer, quant_bits=8):
super(MSEChannelWiseWeightObserverLayer, self).__init__(
layer, quant_bits=quant_bits)
def _cal_abs_max(self, inputs):
reduce_axis = tuple(
[i for i in range(len(inputs.shape)) if i != self.quant_axis()])
abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
np.float32(1e-8), abs_max_values)
return abs_max_values
...@@ -31,6 +31,8 @@ from paddleslim.quant.observers.mse import MSEObserverLayer ...@@ -31,6 +31,8 @@ from paddleslim.quant.observers.mse import MSEObserverLayer
from paddleslim.quant.observers.avg import AVGObserverLayer from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.emd import EMDObserverLayer from paddleslim.quant.observers.emd import EMDObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer from paddleslim.quant.observers.kl import KLObserverLayer
from paddleslim.quant.observers.mse_weight import MSEChannelWiseWeightObserver
from paddleslim.quant.observers.abs_max_weight import AbsMaxChannelWiseWeightObserver
from paddle.nn.quant.format import LinearDequanter, LinearQuanter from paddle.nn.quant.format import LinearDequanter, LinearQuanter
import logging import logging
...@@ -70,10 +72,14 @@ class ImperativeLenet(paddle.nn.Layer): ...@@ -70,10 +72,14 @@ class ImperativeLenet(paddle.nn.Layer):
class TestPTQObserverAcc(unittest.TestCase): class TestPTQObserverAcc(unittest.TestCase):
def __init__(self, observer, observer_type, *args, **kvargs): def __init__(self,
activation_observer,
weight_observer=None,
*args,
**kvargs):
super(TestPTQObserverAcc, self).__init__(*args, **kvargs) super(TestPTQObserverAcc, self).__init__(*args, **kvargs)
self.observer = observer self.act_observer = activation_observer
self.observer_type = observer_type self.weight_observer = weight_observer
def setUp(self): def setUp(self):
paddle.set_device("cpu") paddle.set_device("cpu")
...@@ -100,7 +106,9 @@ class TestPTQObserverAcc(unittest.TestCase): ...@@ -100,7 +106,9 @@ class TestPTQObserverAcc(unittest.TestCase):
def init_case(self): def init_case(self):
self.q_config = QuantConfig(activation=None, weight=None) self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config( self.q_config.add_type_config(
paddle.nn.Conv2D, activation=self.observer, weight=self.observer) paddle.nn.Conv2D,
activation=self.act_observer,
weight=self.weight_observer)
def _count_layers(self, model, layer_type): def _count_layers(self, model, layer_type):
count = 0 count = 0
...@@ -202,7 +210,7 @@ class TestPTQObserverAcc(unittest.TestCase): ...@@ -202,7 +210,7 @@ class TestPTQObserverAcc(unittest.TestCase):
quant_model = ptq.quantize(model, inplace=False) quant_model = ptq.quantize(model, inplace=False)
ptq_sample(quant_model) ptq_sample(quant_model)
converted_model = ptq.convert(quant_model, inplace=False) converted_model = ptq.convert(quant_model, inplace=True)
top1_2, top5_2 = test(converted_model) top1_2, top5_2 = test(converted_model)
_logger.info( _logger.info(
...@@ -220,20 +228,26 @@ class TestPTQObserverAcc(unittest.TestCase): ...@@ -220,20 +228,26 @@ class TestPTQObserverAcc(unittest.TestCase):
observer_suite = unittest.TestSuite() observer_suite = unittest.TestSuite()
observer_suite.addTest(
TestPTQObserverAcc( for _observer in [
observer=HistObserver(sign=True, symmetric=True), AVGObserver(),
observer_type=PercentHistObserverLayer)) EMDObserver(),
observer_suite.addTest( MSEObserver(),
TestPTQObserverAcc( KLObserver(bins_count=256),
observer=KLObserver(bins_count=256), observer_type=KLObserverLayer)) HistObserver(sign=True, symmetric=True),
]:
observer_suite.addTest( observer_suite.addTest(
TestPTQObserverAcc(observer=AVGObserver(), observer_type=AVGObserverLayer)) TestPTQObserverAcc(
observer_suite.addTest( activation_observer=_observer, weight_observer=_observer))
TestPTQObserverAcc(observer=EMDObserver(), observer_type=EMDObserverLayer))
observer_suite.addTest( for _weight_observer in [
TestPTQObserverAcc(observer=MSEObserver(), observer_type=MSEObserverLayer)) MSEChannelWiseWeightObserver(),
AbsMaxChannelWiseWeightObserver(),
]:
observer_suite.addTest(
TestPTQObserverAcc(
activation_observer=MSEObserver(),
weight_observer=_weight_observer))
if __name__ == '__main__': if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2) runner = unittest.TextTestRunner(verbosity=2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册