From 973dab86ad1cea374697d1310851de8824068726 Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 8 Mar 2023 17:21:51 +0800 Subject: [PATCH] Enhance the quantization API with some new features (#50816) --- python/paddle/nn/quant/format.py | 10 ++- python/paddle/quantization/config.py | 32 +++---- .../paddle/quantization/quanters/abs_max.py | 46 +++++++++- python/paddle/quantization/quantize.py | 12 +-- .../tests/quantization/test_trace_quanter.py | 87 +++++++++++++++++++ 5 files changed, 162 insertions(+), 25 deletions(-) create mode 100644 python/paddle/tests/quantization/test_trace_quanter.py diff --git a/python/paddle/nn/quant/format.py b/python/paddle/nn/quant/format.py index d6154942f55..ca5b6ea7f3e 100644 --- a/python/paddle/nn/quant/format.py +++ b/python/paddle/nn/quant/format.py @@ -37,6 +37,7 @@ class LinearQuanterDequanter(Layer): @staticmethod def from_quanter(quanter): + assert quanter is not None return LinearQuanterDequanter( LinearQuanter.from_quanter(quanter), LinearDequanter.from_quanter(quanter), @@ -208,6 +209,8 @@ class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta): self, quanter_name ), f"{quanter_name} is not attribute of current layer." quanter = getattr(self, quanter_name) + if quanter is None: + return None quanter = LinearQuanterDequanter.from_quanter(quanter) setattr(self, quanter_name, quanter) self._sub_layers[quanter_name] = quanter @@ -224,9 +227,10 @@ class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta): assert not self.converted, "The model should be converted only once." for weight_name, quanter_name in self.weights_to_quanters(): qdq = self._convert_quanter_to_qdq(quanter_name) - self._quant_weights(weight_name, qdq._quanter) - qdq._quanter = None - qdq._sub_layers['_quanter'] = None + if qdq is not None: + self._quant_weights(weight_name, qdq._quanter) + qdq._quanter = None + qdq._sub_layers['_quanter'] = None for quanter_name in self.activation_quanters(): self._convert_quanter_to_qdq(quanter_name) diff --git a/python/paddle/quantization/config.py b/python/paddle/quantization/config.py index 8412c7fba90..79776c42e44 100644 --- a/python/paddle/quantization/config.py +++ b/python/paddle/quantization/config.py @@ -89,6 +89,7 @@ class QuantConfig(object): self._type2config = {} self._model = None self._qat_layer_mapping = copy.deepcopy(DEFAULT_QAT_LAYER_MAPPINGS) + self._customized_qat_layer_mapping = dict() self._customized_leaves = [] @@ -259,6 +260,7 @@ class QuantConfig(object): source, paddle.nn.Layer ), "The target layer should be a subclass of paddle.nn.qat.Layer" self._qat_layer_mapping[source] = target + self._customized_qat_layer_mapping[source] = target def add_customized_leaf(self, layer_type: type): r""" @@ -296,7 +298,11 @@ class QuantConfig(object): def _get_qat_layer(self, layer: Layer): q_config = self._get_config_by_layer(layer) - return self.qat_layer_mappings[type(layer)](layer, q_config) + + target_type = self._customized_qat_layer_mapping.get( + type(layer), self.qat_layer_mappings.get(type(layer)) + ) + return target_type(layer, q_config) def _has_observer_config(self, layer: Layer): r""" @@ -397,6 +403,7 @@ class QuantConfig(object): for child in model.children(): layer_prefix = child.full_name() config = self._layer2config.get(model, self.global_config) + config = self._type2config.get(type(child), config) config = self._prefix2config.get(layer_prefix, config) if config is not None: @@ -413,26 +420,21 @@ class QuantConfig(object): return self._details_helper(self._model) def _details_helper(self, layer: Layer): - extra_lines = [] sublayer_lines = [] for name, sublayer in layer.named_children(): sublayer_str = self._details_helper(sublayer) sublayer_str = self._addindent(sublayer_str, 2) - sublayer_lines.append( - '(' - + name - + '): ' - + sublayer_str - + ', ' - + str(self._layer2config[sublayer]) - ) + if sublayer in self._layer2config: + sublayer_lines.append( + '(' + + name + + '): ' + + sublayer_str + + ', ' + + str(self._layer2config[sublayer]) + ) final_str = layer.__class__.__name__ + '(' - if extra_lines: - if len(extra_lines) > 1: - final_str += '\n ' + '\n '.join(extra_lines) + '\n' - elif len(extra_lines) == 1: - final_str += extra_lines[0] if sublayer_lines: final_str += '\n ' + '\n '.join(sublayer_lines) + '\n' diff --git a/python/paddle/quantization/quanters/abs_max.py b/python/paddle/quantization/quanters/abs_max.py index c88269a9a98..14344459eba 100644 --- a/python/paddle/quantization/quanters/abs_max.py +++ b/python/paddle/quantization/quanters/abs_max.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle from paddle import _legacy_C_ops +from paddle.fluid.data_feeder import check_variable_and_dtype from paddle.fluid.framework import _varbase_creator -from paddle.framework import ParamAttr +from paddle.framework import ParamAttr, core from paddle.nn.initializer import Constant from paddle.utils import unique_name @@ -142,7 +144,7 @@ class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter): ) self._accum.stop_gradient = True - def forward(self, input): + def dynamic_forward(self, input): attrs = ( 'moving_rate', self._moving_rate, @@ -181,6 +183,46 @@ class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter): return out + def static_forward(self, input): + check_variable_and_dtype( + input, 'input', ['float32'], "FakeQuantMovingAverageAbsMax" + ) + attrs = { + 'moving_rate': self._moving_rate, + 'bit_length': self._bit_length, + 'is_test': not self.training, + } + inputs = {"X": [input], "InScale": [self._scale]} + quant_out = self._helper.create_variable( + name="{}.quantized.dequantized".format(input.name), + dtype=input.dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False, + ) + outputs = {"Out": [quant_out], "OutScale": [self._scale]} + + if self.training: + inputs['InState'] = [self._state] + inputs['InAccum'] = [self._accum] + outputs['OutState'] = [self._state] + outputs['OutAccum'] = [self._accum] + + self._helper.append_op( + type="fake_quantize_dequantize_moving_average_abs_max", + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) + + return quant_out + + def forward(self, input): + if paddle.framework.in_dynamic_mode(): + return self.dynamic_forward(input) + else: + return self.static_forward(input) + def bit_length(self): return self._bit_length diff --git a/python/paddle/quantization/quantize.py b/python/paddle/quantization/quantize.py index 4c1e257b97e..0d78befb245 100644 --- a/python/paddle/quantization/quantize.py +++ b/python/paddle/quantization/quantize.py @@ -84,11 +84,13 @@ class Quantization(object, metaclass=abc.ABCMeta): def _convert_to_quant_layers(self, model: Layer, config: QuantConfig): replaced = {} for name, child in model.named_children(): - if config._is_quantifiable(child): - if type(child) not in config.qat_layer_mappings: - self._convert_to_quant_layers(child, config) - else: - replaced[name] = config._get_qat_layer(child) + if ( + config._is_quantifiable(child) + and type(child) in config.qat_layer_mappings + ): + replaced[name] = config._get_qat_layer(child) + else: + self._convert_to_quant_layers(child, config) for key, value in replaced.items(): model._sub_layers[key] = value diff --git a/python/paddle/tests/quantization/test_trace_quanter.py b/python/paddle/tests/quantization/test_trace_quanter.py new file mode 100644 index 00000000000..6c42a9a399a --- /dev/null +++ b/python/paddle/tests/quantization/test_trace_quanter.py @@ -0,0 +1,87 @@ +# copyright (c) 2023 paddlepaddle authors. all rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The quantizer layers should be traced by paddle.jit.save function.""" +import os +import tempfile +import unittest + +import paddle +from paddle.quantization import QAT, QuantConfig +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddle.quantization.quanters.abs_max import ( + FakeQuanterWithAbsMaxObserverLayer, +) +from paddle.vision.models import resnet18 + + +class TestPTQ(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(dir="./") + self.path = os.path.join(self.temp_dir.name, 'ptq') + + def tearDown(self): + self.temp_dir.cleanup() + + def _get_model_for_qat(self): + observer = FakeQuanterWithAbsMaxObserver() + model = resnet18() + model.train() + q_config = QuantConfig(activation=None, weight=None) + q_config.add_type_config( + paddle.nn.Conv2D, activation=observer, weight=observer + ) + qat = QAT(q_config) + quant_model = qat.quantize(model) + return quant_model, qat + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def test_trace(self): + + quant_model, ptq = self._get_model_for_qat() + image = paddle.rand([1, 3, 32, 32], dtype="float32") + quantizer_count_in_dygraph = self._count_layers( + quant_model, FakeQuanterWithAbsMaxObserverLayer + ) + save_path = os.path.join(self.path, 'int8_infer') + paddle.jit.save(quant_model, save_path, [image]) + print(f"quant_model is saved into {save_path}") + + paddle.enable_static() + exe = paddle.static.Executor(paddle.CPUPlace()) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + [ + inference_program, + feed_target_names, + fetch_targets, + ] = paddle.static.load_inference_model(save_path, exe) + quantizer_count_in_static_model = 0 + for _op in inference_program.global_block().ops: + if _op.type == "fake_quantize_dequantize_moving_average_abs_max": + quantizer_count_in_static_model += 1 + self.assertEqual( + quantizer_count_in_dygraph, quantizer_count_in_static_model + ) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main() -- GitLab