From b7030257a8b1ec50ac93fce0b0f9358c063d946f Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 16 Feb 2023 14:57:20 +0800 Subject: [PATCH] Add Post-Training Quantization and export function in dygraph mode (#50107) Add PTQ and exporting function 1. Add Post-Training Quantization 1.1 Abstract some functions from QAT to Quantization class 1.2 Add Post-Training Quantization by extending Quantization class 1.3 Add observers for PTQ 1.4 Add unittest for PTQ 2. Add exporting function for QAT and PTQ --- python/paddle/nn/quant/format.py | 234 ++++++++++++++++++ python/paddle/nn/quant/qat/conv.py | 12 +- python/paddle/nn/quant/qat/linear.py | 13 +- python/paddle/quantization/__init__.py | 7 +- python/paddle/quantization/base_observer.py | 32 +++ python/paddle/quantization/factory.py | 3 + .../paddle/quantization/observers/__init__.py | 18 ++ .../paddle/quantization/observers/abs_max.py | 78 ++++++ python/paddle/quantization/ptq.py | 82 ++++++ python/paddle/quantization/qat.py | 38 +-- .../paddle/quantization/quanters/abs_max.py | 4 +- python/paddle/quantization/quantize.py | 112 +++++++++ python/paddle/tests/quantization/test_ptq.py | 134 ++++++++++ python/setup.py.in | 1 + setup.py | 1 + 15 files changed, 730 insertions(+), 39 deletions(-) create mode 100644 python/paddle/nn/quant/format.py create mode 100644 python/paddle/quantization/base_observer.py create mode 100644 python/paddle/quantization/observers/__init__.py create mode 100644 python/paddle/quantization/observers/abs_max.py create mode 100644 python/paddle/quantization/ptq.py create mode 100644 python/paddle/quantization/quantize.py create mode 100644 python/paddle/tests/quantization/test_ptq.py diff --git a/python/paddle/nn/quant/format.py b/python/paddle/nn/quant/format.py new file mode 100644 index 0000000000..d6154942f5 --- /dev/null +++ b/python/paddle/nn/quant/format.py @@ -0,0 +1,234 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Define some layers used to export quantization model with ONNX style.""" +import abc +from typing import List, Tuple + +import paddle +from paddle import _legacy_C_ops as _C_ops +from paddle.framework import in_dygraph_mode +from paddle.nn import Layer + + +class LinearQuanterDequanter(Layer): + def __init__(self, quanter, dequanter): + super(LinearQuanterDequanter, self).__init__() + self._quanter = quanter + self._dequanter = dequanter + + def forward(self, input): + out = input + if self._quanter is not None: + out = self._quanter(out) + if self._dequanter is not None: + out = self._dequanter(out) + return out + + @staticmethod + def from_quanter(quanter): + return LinearQuanterDequanter( + LinearQuanter.from_quanter(quanter), + LinearDequanter.from_quanter(quanter), + ) + + +class LinearQuanter(Layer): + def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8): + super(LinearQuanter, self).__init__() + self._scales = paddle.to_tensor(scales, dtype="float32") + self._zero_point = ( + paddle.zeros([1], dtype="float32") + if zero_point is None + else paddle.to_tensor(zero_point) + ) + self._quant_axis = -1 if quant_axis is None else quant_axis + self._bit_length = bit_length + + def forward(self, input): + if in_dygraph_mode(): + return _C_ops.quantize_linear( + input, + self._scales, + self._zero_point, + "quant_axis", + self._quant_axis, + "bit_length", + self._bit_length, + ) + else: + out = self._helper.create_variable_for_type_inference(input.dtype) + self._helper.append_op( + type='quantize_linear', + inputs={ + 'X': input, + 'Scale': self._scales, + 'ZeroPoint': self._zero_point, + }, + outputs={'Y': out}, + attrs={ + 'quant_axis': self._quant_axis, + 'bit_length': self._bit_length, + }, + ) + return out + + @staticmethod + def from_quanter(quanter): + + return LinearQuanter( + quanter.scales(), + zero_point=quanter.zero_points(), + quant_axis=quanter.quant_axis(), + bit_length=quanter.bit_length(), + ) + + +class LinearDequanter(Layer): + def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8): + super(LinearDequanter, self).__init__() + self._scales = paddle.to_tensor(scales, dtype="float32") + self._zero_point = ( + paddle.zeros([1], dtype="float32") + if zero_point is None + else paddle.to_tensor(zero_point) + ) + self._quant_axis = -1 if quant_axis is None else quant_axis + self._bit_length = bit_length + + def forward(self, input): + if in_dygraph_mode(): + return _C_ops.dequantize_linear( + input, + self._scales, + self._zero_point, + "quant_axis", + self._quant_axis, + "bit_length", + self._bit_length, + ) + else: + out = self._helper.create_variable_for_type_inference(input.dtype) + self._helper.append_op( + type='dequantize_linear', + inputs={ + 'X': input, + 'Scale': self._scales, + 'ZeroPoint': self._zero_point, + }, + outputs={'Y': out}, + attrs={ + 'quant_axis': self._quant_axis, + 'bit_length': self._bit_length, + }, + ) + return out + + @staticmethod + def from_quanter(quanter): + return LinearDequanter( + quanter.scales(), + zero_point=quanter.zero_points(), + quant_axis=quanter.quant_axis(), + bit_length=quanter.bit_length(), + ) + + +class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta): + r"""Abstract class to help convert quantized layer to inference model. + It defines some functions to convert quantizers and observers to quantize + or dequantize operators that maintain the quantization parameters used + during inference. + Examples: + .. code-block:: python + + # Given codes in ./customized_quanter.py + class CustomizedQuantedLayer(ConvertibleQuantedLayer): + def __init__(self): + super(CustomizedQuantedLayer, self).__init__() + self.weight_a = paddle.create_parameter(shape=[1], dtype='float32') + self.weight_b = paddle.create_parameter(shape=[1], dtype='float32') + self.quanter_for_weight_a = None + self.activation_weight = None + def forward(self, input): + qweight_a = self.quanter_for_weight_a(self.weight_a) + weight_b = self.weight_b + qinput = self.activation_weight(input) + // compute with qweight_a, weight_b and qinput. + return qweight * qinput + weight_b + + def weights_to_quanters(self): + return [('weight_a', 'quanter_for_weight_a')] + + def activation_quanters(self): + return ['activation_weight'] + """ + + def __init__(self): + super(ConvertibleQuantedLayer, self).__init__() + self.converted = False + + @abc.abstractmethod + def weights_to_quanters(self) -> List[Tuple[str, str]]: + r"""Get the name pairs of weights to be quantized and their corresponding + quantizers. In the convert function of this abstract class, it will call + the ‘weights_to_quanters’ function and do something as follows: + For each pair, the quantizer will be converted to a quantize operator and + a dequantize operator. Then, the weight will be quantized by the quantize + operator. Finally, the quantize operator will be removed and the weights + will be stored in integer data type. + + Returns: A list of name pairs. Each pair contains two names. The first is name of weight + to be quantized and the second is name of corresponding quanter. + """ + pass + + @abc.abstractmethod + def activation_quanters(self) -> List[str]: + r"""Get the names of quanters used to quantize activations. + All the quanters or observers returned by this function will be converted to quantize + and dequantize operators for deployment. + Returns: A list of quanter names. + """ + pass + + def _convert_quanter_to_qdq(self, quanter_name) -> LinearQuanterDequanter: + r"""Convert quanter to an instance of LinearQuanterDequanter.""" + assert hasattr( + self, quanter_name + ), f"{quanter_name} is not attribute of current layer." + quanter = getattr(self, quanter_name) + quanter = LinearQuanterDequanter.from_quanter(quanter) + setattr(self, quanter_name, quanter) + self._sub_layers[quanter_name] = quanter + return quanter + + def _quant_weights(self, weight_name, quanter): + r"""Quantize the weight by given quanter.""" + weight = getattr(self, weight_name) + qweight = quanter(weight) + weight.set_value(qweight) + + def _convert(self): + r"""Convert current layer to onnx style for inference.""" + assert not self.converted, "The model should be converted only once." + for weight_name, quanter_name in self.weights_to_quanters(): + qdq = self._convert_quanter_to_qdq(quanter_name) + self._quant_weights(weight_name, qdq._quanter) + qdq._quanter = None + qdq._sub_layers['_quanter'] = None + + for quanter_name in self.activation_quanters(): + self._convert_quanter_to_qdq(quanter_name) + + self.converted = True diff --git a/python/paddle/nn/quant/qat/conv.py b/python/paddle/nn/quant/qat/conv.py index d6ee061f3d..4c8e6915c1 100644 --- a/python/paddle/nn/quant/qat/conv.py +++ b/python/paddle/nn/quant/qat/conv.py @@ -17,10 +17,12 @@ Layers used for QAT. from paddle.nn import Layer from paddle.nn import functional as F +from ..format import ConvertibleQuantedLayer -class QuantedConv2D(Layer): + +class QuantedConv2D(ConvertibleQuantedLayer): """ - The computational logic of QuantizedConv2D is the same with Conv2D. + The computational logic of QuantizedConv2D is the same as Conv2D. The only difference is that its inputs are all fake quantized. """ @@ -77,3 +79,9 @@ class QuantedConv2D(Layer): groups=self._groups, data_format=self._data_format, ) + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/python/paddle/nn/quant/qat/linear.py b/python/paddle/nn/quant/qat/linear.py index 004a493ce7..b089486531 100644 --- a/python/paddle/nn/quant/qat/linear.py +++ b/python/paddle/nn/quant/qat/linear.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. + from paddle.nn import Layer from paddle.nn import functional as F +from ..format import ConvertibleQuantedLayer + -class QuantedLinear(Layer): +class QuantedLinear(ConvertibleQuantedLayer): """ - The computational logic of QuantizedLinear is the same with Linear. + The computational logic of QuantizedLinear is the same as Linear. The only difference is that its inputs are all fake quantized. """ @@ -49,3 +52,9 @@ class QuantedLinear(Layer): def _linear_forward(self, input, weight): out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name) return out + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/python/paddle/quantization/__init__.py b/python/paddle/quantization/__init__.py index beb05125af..61d52e39f3 100644 --- a/python/paddle/quantization/__init__.py +++ b/python/paddle/quantization/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +"""Quantization Module""" +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,12 +50,16 @@ from .imperative.qat import ( from .config import QuantConfig from .base_quanter import BaseQuanter +from .base_observer import BaseObserver from .factory import quanter from .qat import QAT +from .ptq import PTQ __all__ = [ "QuantConfig", "BaseQuanter", + "BaseObserver", "quanter", "QAT", + "PTQ", ] diff --git a/python/paddle/quantization/base_observer.py b/python/paddle/quantization/base_observer.py new file mode 100644 index 0000000000..ede6873ef5 --- /dev/null +++ b/python/paddle/quantization/base_observer.py @@ -0,0 +1,32 @@ +"""Abstract observer class.""" +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +from .base_quanter import BaseQuanter + + +class BaseObserver(BaseQuanter, metaclass=abc.ABCMeta): + r""" + Built-in observers and customized observers should extend this base observer + and implement abstract methods. + """ + + def __init__(self): + super(BaseObserver, self).__init__() + + @abc.abstractmethod + def cal_thresholds(self): + pass diff --git a/python/paddle/quantization/factory.py b/python/paddle/quantization/factory.py index 3fb579bb78..a57a2e95e3 100644 --- a/python/paddle/quantization/factory.py +++ b/python/paddle/quantization/factory.py @@ -70,6 +70,9 @@ class QuanterFactory(ClassWithArguments): return self.partial_class(layer) +ObserverFactory = QuanterFactory + + def quanter(class_name): r""" Annotation to declare a factory class for quanter. diff --git a/python/paddle/quantization/observers/__init__.py b/python/paddle/quantization/observers/__init__.py new file mode 100644 index 0000000000..733b3e7dbb --- /dev/null +++ b/python/paddle/quantization/observers/__init__.py @@ -0,0 +1,18 @@ +"""Observers""" +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .abs_max import AbsmaxObserver + +__all__ = ["AbsmaxObserver"] diff --git a/python/paddle/quantization/observers/abs_max.py b/python/paddle/quantization/observers/abs_max.py new file mode 100644 index 0000000000..4c29dd907a --- /dev/null +++ b/python/paddle/quantization/observers/abs_max.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle + +from ..base_observer import BaseObserver +from ..factory import ObserverFactory + + +class AbsmaxObserver(ObserverFactory): + r""" + It collects maximum absolute values of target tensor. + + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + + Examples: + .. code-block:: python + + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AbsmaxObserver, self).__init__(quant_bits=quant_bits) + + def _get_class(self): + return AbsmaxObserverLayer + + +class AbsmaxObserverLayer(BaseObserver): + """ + Per-tensor abs max quantizer. + """ + + INIT_ABS_MAX = 1e-7 + + def __init__(self, layer, quant_bits=8): + super(AbsmaxObserverLayer, self).__init__() + self._quant_bits = quant_bits + self.abs_max_val = paddle.to_tensor(AbsmaxObserverLayer.INIT_ABS_MAX) + + def forward(self, input): + abs_max_val = paddle.max(paddle.abs(input)) + self.abs_max_val = paddle.maximum(abs_max_val, self.abs_max_val) + return input + + def cal_thresholds(self): + self.thresholds = self.abs_max_val + + def bit_length(self): + return self._quant_bits + + def quant_axis(self): + return -1 + + def scales(self): + return self.abs_max_val + + def zero_points(self): + return None diff --git a/python/paddle/quantization/ptq.py b/python/paddle/quantization/ptq.py new file mode 100644 index 0000000000..a9204397b7 --- /dev/null +++ b/python/paddle/quantization/ptq.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import paddle.distributed.fleet as fleet +from paddle.nn import Layer + +from .config import QuantConfig +from .quantize import Quantization + + +class PTQ(Quantization): + """ + Applying post training quantization to the model. + """ + + def __init__(self, config: QuantConfig): + super(PTQ, self).__init__(config) + + def _is_parallel_training(self): + try: + if fleet.worker_num() > 2: + return True + else: + return False + except Exception: # fleet is not initialized + return False + + def quantize(self, model: Layer, inplace=False): + r""" + Create a model for post-training quantization. + + The quantization configuration will be propagated in the model. + And it will insert observers into the model to collect and compute + quantization parameters. + + Args: + model(Layer) - The model to be quantized. + inplace(bool) - Whether to modify the model in-place. + + Return: The prepared model for post-training quantization. + + Examples: + .. code-block:: python + from paddle.quantization import PTQ, QuantConfig + from paddle.quantization.observers import AbsmaxObserver + from paddle.vision.models import LeNet + + observer = AbsmaxObserver() + q_config = QuantConfig(activation=observer, weight=observer) + ptq = PTQ(q_config) + model = LeNet() + model.eval() + quant_model = ptq.quantize(model) + print(quant_model) + """ + _model = model + if not inplace: + assert ( + not self._is_parallel_training() + ), "'inplace' is not compatible with parallel training." + _model = copy.deepcopy(model) + _model.eval() + assert ( + not model.training + ), "Post-Training Quantization shoud not work on training models. Please set evaluation mode by model.eval()." + self._config._specify(_model) + self._convert_to_quant_layers(_model, self._config) + self._insert_activation_observers(_model, self._config) + return _model diff --git a/python/paddle/quantization/qat.py b/python/paddle/quantization/qat.py index e70b56ec18..e7a28a3b3a 100644 --- a/python/paddle/quantization/qat.py +++ b/python/paddle/quantization/qat.py @@ -17,9 +17,10 @@ import copy from paddle.nn import Layer from .config import QuantConfig +from .quantize import Quantization -class QAT(object): +class QAT(Quantization): r""" Tools used to prepare model for quantization-aware training. Args: @@ -35,7 +36,7 @@ class QAT(object): """ def __init__(self, config: QuantConfig): - self._config = copy.deepcopy(config) + super(QAT, self).__init__(config) def quantize(self, model: Layer, inplace=False): r""" @@ -63,38 +64,11 @@ class QAT(object): quant_model = qat.quantize(model) print(quant_model) """ + assert ( + model.training + ), "Quantization-Aware Training shoud work on training models. Please set training mode by model.train()." _model = model if inplace else copy.deepcopy(model) self._config._specify(_model) self._convert_to_quant_layers(_model, self._config) self._insert_activation_observers(_model, self._config) return _model - - def _convert_to_quant_layers(self, model: Layer, config: QuantConfig): - replaced = {} - for name, child in model.named_children(): - if config._is_quantifiable(child): - if type(child) not in config.qat_layer_mappings: - self._convert_to_quant_layers(child, config) - else: - replaced[name] = config._get_qat_layer(child) - for key, value in replaced.items(): - model._sub_layers[key] = value - - def _insert_activation_observers(self, model: Layer, config: QuantConfig): - replaced = {} - for name, child in model.named_children(): - if config._need_observe(child): - replaced[name] = config._get_observe_wrapper(child) - else: - self._insert_activation_observers(child, config) - for key, value in replaced.items(): - model._sub_layers[key] = value - - def _details(self): - return self._config.details() - - def __str__(self): - return self._details() - - def __repr__(self): - return self.__str__() diff --git a/python/paddle/quantization/quanters/abs_max.py b/python/paddle/quantization/quanters/abs_max.py index c80f2bf21e..c88269a9a9 100644 --- a/python/paddle/quantization/quanters/abs_max.py +++ b/python/paddle/quantization/quanters/abs_max.py @@ -182,10 +182,10 @@ class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter): return out def bit_length(self): - return self.bits + return self._bit_length def quant_axis(self): - return None + return -1 def scales(self): return self._scale diff --git a/python/paddle/quantization/quantize.py b/python/paddle/quantization/quantize.py new file mode 100644 index 0000000000..4c1e257b97 --- /dev/null +++ b/python/paddle/quantization/quantize.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import copy + +from paddle.nn import Layer +from paddle.nn.quant.format import ( + ConvertibleQuantedLayer, + LinearQuanterDequanter, +) + +from .base_quanter import BaseQuanter +from .config import QuantConfig + + +class Quantization(object, metaclass=abc.ABCMeta): + r""" + Abstract class used to prepares a copy of the model for quantization calibration or quantization-aware training. + Args: + config(QuantConfig) - Quantization configuration + """ + + def __init__(self, config: QuantConfig): + self._config = copy.deepcopy(config) + + @abc.abstractmethod + def quantize(self, model: Layer, inplace=False): + r"""Create a model for quantization-aware training or post-training quantization.""" + pass + + def convert(self, model: Layer, inplace=False): + r"""Convert the quantization model to onnx style. And the converted + model can be saved as inference model by calling paddle.jit.save. + Args: + model(Layer) - The quantized model to be covnerted. + inplace(bool) - Whether to modify the model in-place. + + Return: The converted model + + Examples: + .. code-block:: python + import paddle + from paddle.quantization import QAT, QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + from paddle.vision.models import LeNet + + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) + q_config = QuantConfig(activation=quanter, weight=quanter) + qat = QAT(q_config) + model = LeNet() + quantized_model = qat.quantize(model) + converted_model = qat.convert(quantized_model) + dummy_data = paddle.rand([1, 1, 32, 32], dtype="float32") + paddle.jit.save(converted_model, "./quant_deploy", [dummy_data]) + """ + _model = model if inplace else copy.deepcopy(model) + replaced = {} + for name, child in _model.named_children(): + quant_dequant = None + if isinstance(child, ConvertibleQuantedLayer): + child._convert() + elif isinstance(child, BaseQuanter): + quant_dequant = LinearQuanterDequanter.from_quanter(child) + else: + self.convert(child, inplace=True) + if quant_dequant is not None: + replaced[name] = quant_dequant + for key, value in replaced.items(): + _model._sub_layers[key] = value + return _model + + def _convert_to_quant_layers(self, model: Layer, config: QuantConfig): + replaced = {} + for name, child in model.named_children(): + if config._is_quantifiable(child): + if type(child) not in config.qat_layer_mappings: + self._convert_to_quant_layers(child, config) + else: + replaced[name] = config._get_qat_layer(child) + for key, value in replaced.items(): + model._sub_layers[key] = value + + def _insert_activation_observers(self, model: Layer, config: QuantConfig): + replaced = {} + for name, child in model.named_children(): + if config._need_observe(child): + replaced[name] = config._get_observe_wrapper(child) + else: + self._insert_activation_observers(child, config) + for key, value in replaced.items(): + model._sub_layers[key] = value + + def _details(self): + return self._config.details() + + def __str__(self): + return self._details() + + def __repr__(self): + return self.__str__() diff --git a/python/paddle/tests/quantization/test_ptq.py b/python/paddle/tests/quantization/test_ptq.py new file mode 100644 index 0000000000..f5237fdd87 --- /dev/null +++ b/python/paddle/tests/quantization/test_ptq.py @@ -0,0 +1,134 @@ +# copyright (c) 2023 paddlepaddle authors. all rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.nn import Conv2D, Linear, ReLU, Sequential +from paddle.nn.quant.format import LinearDequanter, LinearQuanter +from paddle.quantization import PTQ, QuantConfig +from paddle.quantization.observers import AbsmaxObserver +from paddle.quantization.observers.abs_max import AbsmaxObserverLayer + + +class LeNetDygraph(paddle.nn.Layer): + def __init__(self, num_classes=10): + super(LeNetDygraph, self).__init__() + self.num_classes = num_classes + self.features = Sequential( + Conv2D(1, 6, 3, stride=1, padding=1), + ReLU(), + paddle.nn.MaxPool2D(2, 2), + Conv2D(6, 16, 5, stride=1, padding=0), + ReLU(), + paddle.nn.MaxPool2D(2, 2), + ) + + if num_classes > 0: + self.fc = Sequential( + Linear(576, 120), Linear(120, 84), Linear(84, 10) + ) + + def forward(self, inputs): + x = self.features(inputs) + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x) + out = F.relu(x) + return out + + +class TestPTQ(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.path = os.path.join(self.temp_dir.name, 'ptq') + + def tearDown(self): + self.temp_dir.cleanup() + + def _get_model_for_ptq(self): + observer = AbsmaxObserver(quant_bits=8) + model = LeNetDygraph() + model.eval() + q_config = QuantConfig(activation=observer, weight=observer) + ptq = PTQ(q_config) + quant_model = ptq.quantize(model) + return quant_model, ptq + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def test_quantize(self): + ptq_model, _ = self._get_model_for_ptq() + image = paddle.rand([1, 1, 32, 32], dtype="float32") + out = ptq_model(image) + self.assertIsNotNone(out) + + observer_count = self._count_layers(ptq_model, AbsmaxObserverLayer) + self.assertEqual(observer_count, 14) + + def test_convert(self): + + quant_model, ptq = self._get_model_for_ptq() + + image = paddle.rand([1, 1, 32, 32], dtype="float32") + converted_model = ptq.convert(quant_model) + out = converted_model(image) + self.assertIsNotNone(out) + + observer_count = self._count_layers( + converted_model, AbsmaxObserverLayer + ) + quanter_count = self._count_layers(converted_model, LinearQuanter) + dequanter_count = self._count_layers(converted_model, LinearDequanter) + self.assertEqual(observer_count, 0) + self.assertEqual(dequanter_count, 14) + self.assertEqual(quanter_count, 9) + + save_path = os.path.join(self.temp_dir.name, 'int8_infer') + paddle.jit.save(converted_model, save_path, [image]) + + paddle.enable_static() + exe = paddle.static.Executor(paddle.CPUPlace()) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + [ + inference_program, + feed_target_names, + fetch_targets, + ] = paddle.static.load_inference_model(save_path, exe) + tensor_img = np.array( + np.random.random((1, 1, 32, 32)), dtype=np.float32 + ) + results = exe.run( + inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets, + ) + self.assertIsNotNone(results) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 29fe8b4551..f5a90e5db9 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -440,6 +440,7 @@ packages=['paddle', 'paddle.incubate.fleet.parameter_server.distribute_transpiler', 'paddle.quantization', 'paddle.quantization.quanters', + 'paddle.quantization.observers', 'paddle.sparse', 'paddle.sparse.nn', 'paddle.sparse.nn.layer', diff --git a/setup.py b/setup.py index 722c541ab0..f9d5aeac07 100644 --- a/setup.py +++ b/setup.py @@ -1326,6 +1326,7 @@ def get_setup_parameters(): 'paddle.incubate.fleet.parameter_server.distribute_transpiler', 'paddle.quantization', 'paddle.quantization.quanters', + 'paddle.quantization.observers', 'paddle.sparse', 'paddle.sparse.nn', 'paddle.sparse.nn.layer', -- GitLab