未验证 提交 f54331a6 编写于 作者: C Chang Xu 提交者: GitHub

Add More Observers (#1690)

上级 72633164
...@@ -14,5 +14,10 @@ ...@@ -14,5 +14,10 @@
from .hist import HistObserver from .hist import HistObserver
from .kl import KLObserver from .kl import KLObserver
from .mse import MSEObserver
from .emd import EMDObserver
from .avg import AVGObserver
__all__ = ["HistObserver", "KLObserver"] __all__ = [
\ No newline at end of file "HistObserver", "KLObserver", "MSEObserver", "EMDObserver", "AVGObserver"
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class AVGObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(AVGObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return AVGObserverLayer
class AVGObserverLayer(UniformObserver):
def __init__(
self,
layer,
quant_bits=8, ):
super(AVGObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._avg_list = []
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._avg_min, self._avg_max = self.cal_min_max(inputs)
self._avg_list.append(self._avg_max)
return inputs
def cal_min_max(self, inputs):
abs_avg_value = paddle.abs(inputs.reshape((inputs.shape[0], -1)))
abs_avg_value = float(paddle.mean(paddle.max(abs_avg_value, axis=(1))))
return 0, abs_avg_value
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._avg_min, paddle.mean(
paddle.to_tensor(self._avg_list))
self._scale, self._zero_point = self.cal_scales_zero_points()
def min_value(self) -> float:
return self._min
def max_value(self) -> float:
return self._max
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class EMDObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(EMDObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return EMDObserverLayer
class EMDObserverLayer(UniformObserver):
def __init__(self, layer, quant_bits=8):
super(EMDObserverLayer, self).__init__(quant_bits=quant_bits)
self._quant_bits = quant_bits
self._calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._emd_min, self._emd_max = self.cal_min_max(inputs)
return inputs
def cal_min_max(self, inputs):
abs_max_value = float(paddle.max(paddle.flatten(inputs)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
bins = 2**(self._quant_bits - 1) - 1
quant_var = paddle.clip(
paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
self.qmax)
quant_dequant_var = quant_var / self.qmax * scale
emd_loss = paddle.abs(
paddle.mean(inputs) - paddle.mean(quant_dequant_var)
) + paddle.abs(paddle.std(inputs) - paddle.std(quant_dequant_var))
emd_loss = float(emd_loss)
if emd_loss <= self._calibration_loss:
self._calibration_loss = emd_loss
return 0, scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._emd_min, self._emd_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def min_value(self) -> float:
return self._min
def max_value(self) -> float:
return self._max
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory
class MSEObserver(ObserverFactory):
r"""
It collects maximum absolute values of target tensor.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8):
super(MSEObserver, self).__init__(quant_bits=quant_bits)
def _get_class(self):
return MSEObserverLayer
class MSEObserverLayer(UniformObserver):
def __init__(self, layer, quant_bits=8):
super(MSEObserverLayer, self).__init__(quant_bits=quant_bits)
self.quant_bits = quant_bits
self.calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
def forward(self, inputs):
""" Calculate forward pass.
"""
self._scale = None
self._zero_point = None
self._min = None
self._max = None
self._mse_min, self._mse_max = self.cal_min_max(inputs)
return inputs
def cal_min_max(self, inputs):
abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
quant_var = paddle.clip(
paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
self.qmax)
quant_dequant_var = quant_var / self.qmax * scale
mse_loss = float(((inputs - quant_dequant_var)**2).mean())
if mse_loss <= self.calibration_loss:
self.calibration_loss = mse_loss
return 0, scale
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._min, self._max = self._mse_min, self._mse_max
self._scale, self._zero_point = self.cal_scales_zero_points()
def min_value(self) -> float:
return self._min
def max_value(self) -> float:
return self._max
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return -1
def scales(self):
""" Return output scales.
"""
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
...@@ -22,15 +22,19 @@ from paddle.vision.models import resnet18 ...@@ -22,15 +22,19 @@ from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig from paddle.quantization import QuantConfig
from paddle.quantization import PTQ from paddle.quantization import PTQ
from paddleslim.quant.observers import HistObserver, KLObserver from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver
from paddleslim.quant.observers.hist import PercentHistObserverLayer from paddleslim.quant.observers.hist import PercentHistObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer from paddleslim.quant.observers.kl import KLObserverLayer
from paddleslim.quant.observers.mse import MSEObserverLayer
from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.emd import EMDObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter from paddle.nn.quant.format import LinearDequanter, LinearQuanter
class TestPTQWithHistObserver(unittest.TestCase): class TestPTQObserver(unittest.TestCase):
def __init__(self, observer, observer_type, *args, **kvargs): def __init__(self, observer, observer_type, *args, **kvargs):
super(TestPTQWithHistObserver, self).__init__(*args, **kvargs) super(TestPTQObserver, self).__init__(*args, **kvargs)
self.observer = observer self.observer = observer
self.observer_type = observer_type self.observer_type = observer_type
...@@ -49,8 +53,6 @@ class TestPTQWithHistObserver(unittest.TestCase): ...@@ -49,8 +53,6 @@ class TestPTQWithHistObserver(unittest.TestCase):
self.test_convert() self.test_convert()
def init_case(self): def init_case(self):
# observer = HistObserver()
# self.observer_type = PercentHistObserverLayer
self.q_config = QuantConfig(activation=None, weight=None) self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config( self.q_config.add_type_config(
paddle.nn.Conv2D, activation=self.observer, weight=self.observer) paddle.nn.Conv2D, activation=self.observer, weight=self.observer)
...@@ -96,25 +98,31 @@ class TestPTQWithHistObserver(unittest.TestCase): ...@@ -96,25 +98,31 @@ class TestPTQWithHistObserver(unittest.TestCase):
observer_suite = unittest.TestSuite() observer_suite = unittest.TestSuite()
observer_suite.addTest( observer_suite.addTest(
TestPTQWithHistObserver( TestPTQObserver(
observer=HistObserver(sign=True, symmetric=True), observer=HistObserver(sign=True, symmetric=True),
observer_type=PercentHistObserverLayer)) observer_type=PercentHistObserverLayer))
observer_suite.addTest( observer_suite.addTest(
TestPTQWithHistObserver( TestPTQObserver(
observer=HistObserver(sign=False, symmetric=True), observer=HistObserver(sign=False, symmetric=True),
observer_type=PercentHistObserverLayer)) observer_type=PercentHistObserverLayer))
observer_suite.addTest( observer_suite.addTest(
TestPTQWithHistObserver( TestPTQObserver(
observer=HistObserver(sign=True, symmetric=False), observer=HistObserver(sign=True, symmetric=False),
observer_type=PercentHistObserverLayer)) observer_type=PercentHistObserverLayer))
observer_suite.addTest( observer_suite.addTest(
TestPTQWithHistObserver( TestPTQObserver(
observer=HistObserver(sign=False, symmetric=False), observer=HistObserver(sign=False, symmetric=False),
observer_type=PercentHistObserverLayer)) observer_type=PercentHistObserverLayer))
observer_suite.addTest( observer_suite.addTest(
TestPTQWithHistObserver( TestPTQObserver(
observer=KLObserver(bins_count=256), observer_type=KLObserverLayer)) observer=KLObserver(bins_count=256), observer_type=KLObserverLayer))
observer_suite.addTest(
TestPTQObserver(observer=AVGObserver(), observer_type=AVGObserverLayer))
observer_suite.addTest(
TestPTQObserver(observer=EMDObserver(), observer_type=EMDObserverLayer))
observer_suite.addTest(
TestPTQObserver(observer=MSEObserver(), observer_type=MSEObserverLayer))
if __name__ == '__main__': if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2) runner = unittest.TextTestRunner(verbosity=2)
runner.run(observer_suite) runner.run(observer_suite)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册