未验证 提交 b53888e7 编写于 作者: W whs 提交者: GitHub

Add API for quantization-aware training in dygraph mode (#49398)

* Add tools for quantization-aware training
1. Expose an API named paddle.quantization.QAT
2. Define a wrapper class to insert quanters into model for QAT
3. Add some functions in QuantConfig for QAT
4. Add unittest for QAT

* Add QuantedConv2D and QuantedLinear for QAT

* Add paddle.nn.quant.qat to setup.py
上级 e0b50269
......@@ -23,5 +23,6 @@ from .functional_layers import concat # noqa: F401
from .functional_layers import flatten # noqa: F401
from .functional_layers import matmul # noqa: F401
from .quant_layers import QuantStub # noqa: F401
from . import qat
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv import QuantedConv2D
from .linear import QuantedLinear
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Layers used for QAT.
"""
from paddle.nn import Layer
from paddle.nn import functional as F
class QuantedConv2D(Layer):
"""
The computational logic of QuantizedConv2D is the same with Conv2D.
The only difference is that its inputs are all fake quantized.
"""
def __init__(self, layer: Layer, q_config):
super(QuantedConv2D, self).__init__()
# For Conv2D
self._groups = getattr(layer, '_groups')
self._stride = getattr(layer, '_stride')
self._padding = getattr(layer, '_padding')
self._padding_mode = getattr(layer, '_padding_mode')
if self._padding_mode != 'zeros':
self._reversed_padding_repeated_twice = getattr(
layer, '_reversed_padding_repeated_twice'
)
self._dilation = getattr(layer, '_dilation')
self._data_format = getattr(layer, '_data_format')
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
self.weight_quanter = None
self.activation_quanter = None
if q_config.weight is not None:
self.weight_quanter = q_config.weight._instance(layer)
if q_config.activation is not None:
self.activation_quanter = q_config.activation._instance(layer)
def forward(self, input):
quant_input = input
quant_weight = self.weight
if self.activation_quanter is not None:
quant_input = self.activation_quanter(input)
if self.weight_quanter is not None:
quant_weight = self.weight_quanter(self.weight)
return self._conv_forward(quant_input, quant_weight)
def _conv_forward(self, inputs, weights):
if self._padding_mode != 'zeros':
inputs = F.pad(
inputs,
self._reversed_padding_repeated_twice,
mode=self._padding_mode,
data_format=self._data_format,
)
self._padding = 0
return F.conv2d(
inputs,
weights,
bias=self.bias,
padding=self._padding,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format,
)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Layer
from paddle.nn import functional as F
class QuantedLinear(Layer):
"""
The computational logic of QuantizedLinear is the same with Linear.
The only difference is that its inputs are all fake quantized.
"""
def __init__(self, layer: Layer, q_config):
super(QuantedLinear, self).__init__()
# For Linear
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
self.name = getattr(layer, 'name')
# For FakeQuant
self.weight_quanter = None
self.activation_quanter = None
if q_config.weight is not None:
self.weight_quanter = q_config.weight._instance(layer)
if q_config.activation is not None:
self.activation_quanter = q_config.activation._instance(layer)
def forward(self, input):
quant_input = input
quant_weight = self.weight
if self.activation_quanter is not None:
quant_input = self.activation_quanter(input)
if self.weight_quanter is not None:
quant_weight = self.weight_quanter(self.weight)
return self._linear_forward(quant_input, quant_weight)
def _linear_forward(self, input, weight):
out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name)
return out
......@@ -50,9 +50,11 @@ from .imperative.qat import (
from .config import QuantConfig
from .base_quanter import BaseQuanter
from .factory import quanter
from .qat import QAT
__all__ = [
"QuantConfig",
"BaseQuanter",
"quanter",
"QAT",
]
......@@ -20,9 +20,13 @@ import paddle.nn as nn
from paddle.nn import Layer
from .factory import QuanterFactory
from .wrapper import ObserveWrapper
# TODO: Implement quanted layer and fill the mapping dict
DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = {}
DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = {
nn.Linear: nn.quant.qat.QuantedLinear,
nn.Conv2D: nn.quant.qat.QuantedConv2D,
}
DEFAULT_LEAVES = [nn.ReLU, nn.AvgPool2D]
......@@ -289,6 +293,10 @@ class QuantConfig(object):
"""
return self._is_leaf(layer) and self._has_observer_config(layer)
def _get_qat_layer(self, layer: Layer):
q_config = self._get_config_by_layer(layer)
return self.qat_layer_mappings[type(layer)](layer, q_config)
def _has_observer_config(self, layer: Layer):
r"""
Whether the layer has been configured for activation quantization.
......@@ -324,6 +332,10 @@ class QuantConfig(object):
_observer = None if _config is None else _config.activation
return None if _observer is None else _observer._instance(layer)
def _get_observe_wrapper(self, layer: Layer):
_observer = self._get_observer(layer)
return ObserveWrapper(_observer, layer)
@property
def qat_layer_mappings(self):
return self._qat_layer_mapping
......@@ -395,6 +407,8 @@ class QuantConfig(object):
r"""
Get the formated details of current config.
"""
if self._model is None:
return self.__str__()
return self._details_helper(self._model)
def _details_helper(self, layer: Layer):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from paddle.nn import Layer
from .config import QuantConfig
class QAT(object):
r"""
Tools used to prepare model for quantization-aware training.
Args:
config(QuantConfig) - Quantization configuration
Examples:
.. code-block:: python
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
qat = QAT(q_config)
"""
def __init__(self, config: QuantConfig):
self._config = copy.deepcopy(config)
def quantize(self, model: Layer, inplace=False):
r"""
Create a model for quantization-aware training.
The quantization configuration will be propagated in the model.
And it will insert fake quanters into the model to simulate the quantization.
Args:
model(Layer) - The model to be quantized.
inplace(bool) - Whether to modify the model in-place.
Return: The prepared model for quantization-aware training.
Examples:
.. code-block:: python
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.vision.models import LeNet
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
qat = QAT(q_config)
model = LeNet()
quant_model = qat.quantize(model)
print(quant_model)
"""
_model = model if inplace else copy.deepcopy(model)
self._config._specify(_model)
self._convert_to_quant_layers(_model, self._config)
self._insert_activation_observers(_model, self._config)
return _model
def _convert_to_quant_layers(self, model: Layer, config: QuantConfig):
replaced = {}
for name, child in model.named_children():
if config._is_quantifiable(child):
if type(child) not in config.qat_layer_mappings:
self._convert_to_quant_layers(child, config)
else:
replaced[name] = config._get_qat_layer(child)
for key, value in replaced.items():
model._sub_layers[key] = value
def _insert_activation_observers(self, model: Layer, config: QuantConfig):
replaced = {}
for name, child in model.named_children():
if config._need_observe(child):
replaced[name] = config._get_observe_wrapper(child)
else:
self._insert_activation_observers(child, config)
for key, value in replaced.items():
model._sub_layers[key] = value
def _details(self):
return self._config.details()
def __str__(self):
return self._details()
def __repr__(self):
return self.__str__()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Layer
from .base_quanter import BaseQuanter
class ObserveWrapper(Layer):
r"""
Put an observer layer and an observed layer into a wrapping layer.
It is used to insert layers into the model for QAT or PTQ.
Args:
observer(BaseQuanter) - Observer layer
observed(Layer) - Observed layer
observe_input(bool) - If it is true the observer layer will be called before observed layer.
If it is false the observed layer will be called before observer layer. Default: True.
"""
def __init__(
self,
observer: BaseQuanter,
observed: Layer,
observe_input=True,
):
super(ObserveWrapper, self).__init__()
self._observer = observer
self._observed = observed
self._observe_input = observe_input
def forward(self, *inputs, **kwargs):
if self._observe_input:
out = self._observer(*inputs, **kwargs)
return self._observed(out, **kwargs)
else:
out = self._observed(*inputs, **kwargs)
return self._observer(out, **kwargs)
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.nn import Conv2D, Linear, ReLU, Sequential
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.quantization.quanters.abs_max import (
FakeQuanterWithAbsMaxObserverLayer,
)
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
data = np.random.random([3, 32, 32]).astype('float32')
return data
def __len__(self):
return self.num_samples
class Model(paddle.nn.Layer):
def __init__(self, num_classes=10):
super(Model, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2D(3, 6, 3, stride=1, padding=1),
ReLU(),
paddle.nn.MaxPool2D(2, stride=2),
Conv2D(6, 16, 5, stride=1, padding=0),
ReLU(),
paddle.nn.MaxPool2D(2, stride=2),
)
if num_classes > 0:
self.fc = Sequential(
Linear(576, 120), Linear(120, 84), Linear(84, 10)
)
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
out = F.relu(x)
return out
class TestQAT(unittest.TestCase):
def test_qat(self):
nums_batch = 100
batch_size = 32
dataset = RandomDataset(nums_batch * batch_size)
loader = paddle.io.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=True,
num_workers=0,
)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
qat = QAT(q_config)
print(model)
quant_model = qat.quantize(model)
print(quant_model)
quanter_count = 0
for _layer in quant_model.sublayers(True):
if isinstance(_layer, FakeQuanterWithAbsMaxObserverLayer):
quanter_count += 1
self.assertEqual(quanter_count, 14)
for _, data in enumerate(loader):
out = quant_model(data)
out.backward()
if __name__ == '__main__':
unittest.main()
......@@ -393,6 +393,7 @@ packages=['paddle',
'paddle.nn.functional',
'paddle.nn.layer',
'paddle.nn.quant',
'paddle.nn.quant.qat',
'paddle.nn.initializer',
'paddle.nn.utils',
'paddle.metric',
......
......@@ -1292,6 +1292,7 @@ def get_setup_parameters():
'paddle.nn.functional',
'paddle.nn.layer',
'paddle.nn.quant',
'paddle.nn.quant.qat',
'paddle.nn.initializer',
'paddle.nn.utils',
'paddle.metric',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册