From b53888e78f8675527179a4cf6aede1aa1c1ee374 Mon Sep 17 00:00:00 2001
From: whs
Date: Wed, 11 Jan 2023 19:55:47 +0800
Subject: [PATCH] Add API for quantization-aware training in dygraph mode
 (#49398)

* Add tools for quantization-aware training

1. Expose an API named paddle.quantization.QAT
2. Define a wrapper class to insert quanters into the model for QAT
3. Add some functions to QuantConfig for QAT
4. Add unittest for QAT

* Add QuantedConv2D and QuantedLinear for QAT

* Add paddle.nn.quant.qat to setup.py
---
 python/paddle/nn/quant/__init__.py           |   1 +
 python/paddle/nn/quant/qat/__init__.py       |  15 +++
 python/paddle/nn/quant/qat/conv.py           |  79 +++++++++++++++
 python/paddle/nn/quant/qat/linear.py         |  51 ++++++++++
 python/paddle/quantization/__init__.py       |   2 +
 python/paddle/quantization/config.py         |  16 ++-
 python/paddle/quantization/qat.py            | 100 +++++++++++++++++++
 python/paddle/quantization/wrapper.py        |  48 +++++++++
 python/paddle/tests/quantization/test_qat.py | 100 +++++++++++++++++++
 python/setup.py.in                           |   1 +
 setup.py                                     |   1 +
 11 files changed, 413 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/nn/quant/qat/__init__.py
 create mode 100644 python/paddle/nn/quant/qat/conv.py
 create mode 100644 python/paddle/nn/quant/qat/linear.py
 create mode 100644 python/paddle/quantization/qat.py
 create mode 100644 python/paddle/quantization/wrapper.py
 create mode 100644 python/paddle/tests/quantization/test_qat.py

diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py
index f96558bfbe..f1c5f7590c 100644
--- a/python/paddle/nn/quant/__init__.py
+++ b/python/paddle/nn/quant/__init__.py
@@ -23,5 +23,6 @@ from .functional_layers import concat  # noqa: F401
 from .functional_layers import flatten  # noqa: F401
 from .functional_layers import matmul  # noqa: F401
 from .quant_layers import QuantStub  # noqa: F401
+from . import qat
 
 __all__ = []
diff --git a/python/paddle/nn/quant/qat/__init__.py b/python/paddle/nn/quant/qat/__init__.py
new file mode 100644
index 0000000000..8701b8af76
--- /dev/null
+++ b/python/paddle/nn/quant/qat/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .conv import QuantedConv2D
+from .linear import QuantedLinear
diff --git a/python/paddle/nn/quant/qat/conv.py b/python/paddle/nn/quant/qat/conv.py
new file mode 100644
index 0000000000..d6ee061f3d
--- /dev/null
+++ b/python/paddle/nn/quant/qat/conv.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Layers used for QAT.
+"""
+from paddle.nn import Layer
+from paddle.nn import functional as F
+
+
+class QuantedConv2D(Layer):
+    """
+    The computational logic of QuantedConv2D is the same as that of Conv2D.
+    The only difference is that its inputs are all fake-quantized.
+    """
+
+    def __init__(self, layer: Layer, q_config):
+        super(QuantedConv2D, self).__init__()
+
+        # For Conv2D
+        self._groups = getattr(layer, '_groups')
+        self._stride = getattr(layer, '_stride')
+        self._padding = getattr(layer, '_padding')
+        self._padding_mode = getattr(layer, '_padding_mode')
+        if self._padding_mode != 'zeros':
+            self._reversed_padding_repeated_twice = getattr(
+                layer, '_reversed_padding_repeated_twice'
+            )
+        self._dilation = getattr(layer, '_dilation')
+        self._data_format = getattr(layer, '_data_format')
+        self.weight = getattr(layer, 'weight')
+        self.bias = getattr(layer, 'bias')
+
+        self.weight_quanter = None
+        self.activation_quanter = None
+        if q_config.weight is not None:
+            self.weight_quanter = q_config.weight._instance(layer)
+        if q_config.activation is not None:
+            self.activation_quanter = q_config.activation._instance(layer)
+
+    def forward(self, input):
+        quant_input = input
+        quant_weight = self.weight
+        if self.activation_quanter is not None:
+            quant_input = self.activation_quanter(input)
+        if self.weight_quanter is not None:
+            quant_weight = self.weight_quanter(self.weight)
+        return self._conv_forward(quant_input, quant_weight)
+
+    def _conv_forward(self, inputs, weights):
+        if self._padding_mode != 'zeros':
+            inputs = F.pad(
+                inputs,
+                self._reversed_padding_repeated_twice,
+                mode=self._padding_mode,
+                data_format=self._data_format,
+            )
+            self._padding = 0
+
+        return F.conv2d(
+            inputs,
+            weights,
+            bias=self.bias,
+            padding=self._padding,
+            stride=self._stride,
+            dilation=self._dilation,
+            groups=self._groups,
+            data_format=self._data_format,
+        )
diff --git a/python/paddle/nn/quant/qat/linear.py b/python/paddle/nn/quant/qat/linear.py
new file mode 100644
index 0000000000..004a493ce7
--- /dev/null
+++ b/python/paddle/nn/quant/qat/linear.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.nn import Layer
+from paddle.nn import functional as F
+
+
+class QuantedLinear(Layer):
+    """
+    The computational logic of QuantedLinear is the same as that of Linear.
+    The only difference is that its inputs are all fake-quantized.
+ """ + + def __init__(self, layer: Layer, q_config): + super(QuantedLinear, self).__init__() + # For Linear + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + self.name = getattr(layer, 'name') + # For FakeQuant + + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(layer) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, input): + quant_input = input + quant_weight = self.weight + if self.activation_quanter is not None: + quant_input = self.activation_quanter(input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.weight) + return self._linear_forward(quant_input, quant_weight) + + def _linear_forward(self, input, weight): + out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name) + return out diff --git a/python/paddle/quantization/__init__.py b/python/paddle/quantization/__init__.py index 8b7f9769e8..beb05125af 100644 --- a/python/paddle/quantization/__init__.py +++ b/python/paddle/quantization/__init__.py @@ -50,9 +50,11 @@ from .imperative.qat import ( from .config import QuantConfig from .base_quanter import BaseQuanter from .factory import quanter +from .qat import QAT __all__ = [ "QuantConfig", "BaseQuanter", "quanter", + "QAT", ] diff --git a/python/paddle/quantization/config.py b/python/paddle/quantization/config.py index 689c614660..2762490831 100644 --- a/python/paddle/quantization/config.py +++ b/python/paddle/quantization/config.py @@ -20,9 +20,13 @@ import paddle.nn as nn from paddle.nn import Layer from .factory import QuanterFactory +from .wrapper import ObserveWrapper # TODO: Implement quanted layer and fill the mapping dict -DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = {} +DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = { + nn.Linear: nn.quant.qat.QuantedLinear, + nn.Conv2D: nn.quant.qat.QuantedConv2D, +} DEFAULT_LEAVES = [nn.ReLU, nn.AvgPool2D] @@ -289,6 +293,10 @@ class QuantConfig(object): """ return self._is_leaf(layer) and self._has_observer_config(layer) + def _get_qat_layer(self, layer: Layer): + q_config = self._get_config_by_layer(layer) + return self.qat_layer_mappings[type(layer)](layer, q_config) + def _has_observer_config(self, layer: Layer): r""" Whether the layer has been configured for activation quantization. @@ -324,6 +332,10 @@ class QuantConfig(object): _observer = None if _config is None else _config.activation return None if _observer is None else _observer._instance(layer) + def _get_observe_wrapper(self, layer: Layer): + _observer = self._get_observer(layer) + return ObserveWrapper(_observer, layer) + @property def qat_layer_mappings(self): return self._qat_layer_mapping @@ -395,6 +407,8 @@ class QuantConfig(object): r""" Get the formated details of current config. """ + if self._model is None: + return self.__str__() return self._details_helper(self._model) def _details_helper(self, layer: Layer): diff --git a/python/paddle/quantization/qat.py b/python/paddle/quantization/qat.py new file mode 100644 index 0000000000..e70b56ec18 --- /dev/null +++ b/python/paddle/quantization/qat.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from paddle.nn import Layer
+
+from .config import QuantConfig
+
+
+class QAT(object):
+    r"""
+    Tools used to prepare a model for quantization-aware training.
+    Args:
+        config(QuantConfig) - Quantization configuration
+
+    Examples:
+        .. code-block:: python
+            from paddle.quantization import QAT, QuantConfig
+            from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+            quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+            q_config = QuantConfig(activation=quanter, weight=quanter)
+            qat = QAT(q_config)
+    """
+
+    def __init__(self, config: QuantConfig):
+        self._config = copy.deepcopy(config)
+
+    def quantize(self, model: Layer, inplace=False):
+        r"""
+        Create a model for quantization-aware training.
+
+        The quantization configuration is propagated through the model,
+        and fake quanters are inserted into the model to simulate quantization.
+
+        Args:
+            model(Layer) - The model to be quantized.
+            inplace(bool) - Whether to modify the model in-place.
+
+        Return: The prepared model for quantization-aware training.
+
+        Examples:
+            .. code-block:: python
+                from paddle.quantization import QAT, QuantConfig
+                from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+                from paddle.vision.models import LeNet
+
+                quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+                q_config = QuantConfig(activation=quanter, weight=quanter)
+                qat = QAT(q_config)
+                model = LeNet()
+                quant_model = qat.quantize(model)
+                print(quant_model)
+        """
+        _model = model if inplace else copy.deepcopy(model)
+        self._config._specify(_model)
+        self._convert_to_quant_layers(_model, self._config)
+        self._insert_activation_observers(_model, self._config)
+        return _model
+
+    def _convert_to_quant_layers(self, model: Layer, config: QuantConfig):
+        replaced = {}
+        for name, child in model.named_children():
+            if config._is_quantifiable(child):
+                if type(child) not in config.qat_layer_mappings:
+                    self._convert_to_quant_layers(child, config)
+                else:
+                    replaced[name] = config._get_qat_layer(child)
+        for key, value in replaced.items():
+            model._sub_layers[key] = value
+
+    def _insert_activation_observers(self, model: Layer, config: QuantConfig):
+        replaced = {}
+        for name, child in model.named_children():
+            if config._need_observe(child):
+                replaced[name] = config._get_observe_wrapper(child)
+            else:
+                self._insert_activation_observers(child, config)
+        for key, value in replaced.items():
+            model._sub_layers[key] = value
+
+    def _details(self):
+        return self._config.details()
+
+    def __str__(self):
+        return self._details()
+
+    def __repr__(self):
+        return self.__str__()
diff --git a/python/paddle/quantization/wrapper.py b/python/paddle/quantization/wrapper.py
new file mode 100644
index 0000000000..96178d2821
--- /dev/null
+++ b/python/paddle/quantization/wrapper.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.nn import Layer
+
+from .base_quanter import BaseQuanter
+
+
+class ObserveWrapper(Layer):
+    r"""
+    Wrap an observer layer and an observed layer into a single layer.
+    It is used to insert layers into the model for QAT or PTQ.
+    Args:
+        observer(BaseQuanter) - Observer layer
+        observed(Layer) - Observed layer
+        observe_input(bool) - If true, the observer layer is called before the observed layer.
+            If false, the observed layer is called before the observer layer. Default: True.
+    """
+
+    def __init__(
+        self,
+        observer: BaseQuanter,
+        observed: Layer,
+        observe_input=True,
+    ):
+        super(ObserveWrapper, self).__init__()
+        self._observer = observer
+        self._observed = observed
+        self._observe_input = observe_input
+
+    def forward(self, *inputs, **kwargs):
+        if self._observe_input:
+            out = self._observer(*inputs, **kwargs)
+            return self._observed(out, **kwargs)
+        else:
+            out = self._observed(*inputs, **kwargs)
+            return self._observer(out, **kwargs)
diff --git a/python/paddle/tests/quantization/test_qat.py b/python/paddle/tests/quantization/test_qat.py
new file mode 100644
index 0000000000..920e6b2bde
--- /dev/null
+++ b/python/paddle/tests/quantization/test_qat.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.nn.functional as F
+from paddle.io import Dataset
+from paddle.nn import Conv2D, Linear, ReLU, Sequential
+from paddle.quantization import QAT, QuantConfig
+from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+from paddle.quantization.quanters.abs_max import (
+    FakeQuanterWithAbsMaxObserverLayer,
+)
+
+
+class RandomDataset(Dataset):
+    def __init__(self, num_samples):
+        self.num_samples = num_samples
+
+    def __getitem__(self, idx):
+        data = np.random.random([3, 32, 32]).astype('float32')
+        return data
+
+    def __len__(self):
+        return self.num_samples
+
+
+class Model(paddle.nn.Layer):
+    def __init__(self, num_classes=10):
+        super(Model, self).__init__()
+        self.num_classes = num_classes
+        self.features = Sequential(
+            Conv2D(3, 6, 3, stride=1, padding=1),
+            ReLU(),
+            paddle.nn.MaxPool2D(2, stride=2),
+            Conv2D(6, 16, 5, stride=1, padding=0),
+            ReLU(),
+            paddle.nn.MaxPool2D(2, stride=2),
+        )
+
+        if num_classes > 0:
+            self.fc = Sequential(
+                Linear(576, 120), Linear(120, 84), Linear(84, 10)
+            )
+
+    def forward(self, inputs):
+        x = self.features(inputs)
+        if self.num_classes > 0:
+            x = paddle.flatten(x, 1)
+            x = self.fc(x)
+        out = F.relu(x)
+        return out
+
+
+class TestQAT(unittest.TestCase):
+    def test_qat(self):
+        nums_batch = 100
+        batch_size = 32
+        dataset = RandomDataset(nums_batch * batch_size)
+        loader = paddle.io.DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=True,
+            num_workers=0,
+        )
+        model = Model()
+        quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+        q_config = QuantConfig(activation=quanter, weight=quanter)
+        qat = QAT(q_config)
+        print(model)
+        quant_model = qat.quantize(model)
+        print(quant_model)
+        quanter_count = 0
+        for _layer in quant_model.sublayers(True):
+            if isinstance(_layer, FakeQuanterWithAbsMaxObserverLayer):
+                quanter_count += 1
+        self.assertEqual(quanter_count, 14)
+
+        for _, data in enumerate(loader):
+            out = quant_model(data)
+            out.backward()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/setup.py.in b/python/setup.py.in
index 25a1204610..d6d01fc4b0 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -393,6 +393,7 @@ packages=['paddle',
           'paddle.nn.functional',
           'paddle.nn.layer',
           'paddle.nn.quant',
+          'paddle.nn.quant.qat',
           'paddle.nn.initializer',
           'paddle.nn.utils',
           'paddle.metric',
diff --git a/setup.py b/setup.py
index 14cbd27f5b..177bb64d28 100644
--- a/setup.py
+++ b/setup.py
@@ -1292,6 +1292,7 @@ def get_setup_parameters():
         'paddle.nn.functional',
         'paddle.nn.layer',
         'paddle.nn.quant',
+        'paddle.nn.quant.qat',
         'paddle.nn.initializer',
         'paddle.nn.utils',
         'paddle.metric',
--
GitLab
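For quick reference, the end-to-end usage of the new API amounts to the following. This is a minimal sketch assembled from the docstring example in python/paddle/quantization/qat.py above; LeNet and moving_rate=0.9 are illustrative choices, not requirements of the API.

    # Minimal QAT usage sketch based on the docstring example in this patch.
    from paddle.quantization import QAT, QuantConfig
    from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
    from paddle.vision.models import LeNet

    # Use the same fake-quanter factory for activations and weights.
    quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
    q_config = QuantConfig(activation=quanter, weight=quanter)

    # quantize() deep-copies the model (inplace=False by default), replaces
    # Conv2D/Linear with QuantedConv2D/QuantedLinear per the mapping in
    # config.py, and wraps observed leaf layers with ObserveWrapper.
    qat = QAT(q_config)
    quant_model = qat.quantize(LeNet())
    print(quant_model)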