Unverified commit 941444aa, authored by W whs, committed by GitHub

Add configure of quantization for dynamic graph (#48000)

Parent ae544586
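In short, this PR introduces a public `paddle.quantization` package with `QuantConfig`, `BaseQuanter`, a `quanter` decorator, and a built-in `FakeQuanterWithAbsMaxObserver`. A minimal usage sketch follows, assuming the toy model below (the model itself is not part of the PR):

# Minimal usage sketch of the new API (illustrative; "SimpleModel"
# is a stand-in model, not part of this PR).
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

class SimpleModel(paddle.nn.Layer):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = Linear(576, 120)

model = SimpleModel()
# A factory: the concrete quanter is instantiated lazily, per layer.
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
# Per-layer settings take priority over the global config.
q_config.add_layer_config([model.fc], activation=quanter, weight=quanter)
print(q_config)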
......@@ -85,6 +85,7 @@ import paddle.vision # noqa: F401
import paddle.audio # noqa: F401
import paddle.geometric # noqa: F401
import paddle.sparse # noqa: F401
+import paddle.quantization  # noqa: F401
from .tensor.attribute import is_complex # noqa: F401
from .tensor.attribute import is_integer # noqa: F401
......
......@@ -12,35 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ...fluid.contrib.slim.quantization.imperative.ptq_config import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_config import (
     PTQConfig,
     default_ptq_config,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     BaseQuantizer,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     AbsmaxQuantizer,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     PerChannelAbsmaxQuantizer,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     KLQuantizer,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     HistQuantizer,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     SUPPORT_ACT_QUANTIZERS,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
     SUPPORT_WT_QUANTIZERS,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq_registry import (
+from ..fluid.contrib.slim.quantization.imperative.ptq_registry import (
     PTQRegistry,
 )
-from ...fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
-from ...fluid.contrib.slim.quantization.imperative.qat import (
+from ..fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
+from ..fluid.contrib.slim.quantization.imperative.qat import (
     ImperativeQuantAware,
 )
+from .config import QuantConfig
+from .base_quanter import BaseQuanter
+from .factory import quanter
+
+__all__ = [
+    "QuantConfig",
+    "BaseQuanter",
+    "quanter",
+]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections.abc import Iterable
from typing import Union
import numpy as np
import paddle
from paddle.nn import Layer
class BaseQuanter(Layer, metaclass=abc.ABCMeta):
r"""
Built-in quanters and customized quanters should extend this base quanter
and implement abstract methods.
"""
def __init__(self):
super(BaseQuanter, self).__init__()
@abc.abstractmethod
def forward(self, input):
pass
@abc.abstractmethod
def scales(self) -> Union[paddle.Tensor, np.ndarray]:
r"""
Get the scales used for quantization.
It can be None, which means the quanter does not hold scales for quantization.
"""
pass
@abc.abstractmethod
def zero_points(self) -> Union[paddle.Tensor, np.ndarray]:
r"""
Get the zero points used for quantization.
It can be None, which means the quanter does not hold zero points for quantization.
"""
pass
@abc.abstractmethod
def quant_axis(self) -> Union[int, Iterable]:
r"""
Get the axis of quantization. None means tensor-wise quantization.
"""
pass
@abc.abstractmethod
def bit_length(self):
r"""
Get the bit length of quantization.
"""
pass
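For reference, a minimal custom quanter built on this base class might look as follows (an illustrative sketch mirroring the unit test later in this diff; "IdentityQuanter" is a hypothetical name, not part of this PR):

# A pass-through quanter sketch extending BaseQuanter.
from paddle.quantization.base_quanter import BaseQuanter

class IdentityQuanter(BaseQuanter):
    def __init__(self):
        super(IdentityQuanter, self).__init__()

    def forward(self, input):
        return input  # pass-through; no real quantization

    def scales(self):
        return None  # holds no scales

    def zero_points(self):
        return None  # symmetric quantization: no zero points

    def quant_axis(self):
        return None  # tensor-wise quantization

    def bit_length(self):
        return 8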
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Dict, Union
import paddle
import paddle.nn as nn
from paddle.nn import Layer
from .factory import QuanterFactory
# TODO: Implement quanted layer and fill the mapping dict
DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = {}
DEFAULT_LEAVES = [nn.ReLU, nn.AvgPool2D]
class SingleLayerConfig(object):
r"""
Configure how to quantize the activations and weights of a single layer.
Args:
activation(QuanterFactory): The factory that creates the quanter instance used to quantize activations.
weight(QuanterFactory): The factory that creates the quanter instance used to quantize weights.
"""
def __init__(self, activation: QuanterFactory, weight: QuanterFactory):
self._activation = activation
self._weight = weight
@property
def activation(self):
return self._activation
@property
def weight(self):
return self._weight
def __str__(self):
return f"activation: {self._activation}\nweight: {self._weight}"
class QuantConfig(object):
r"""
Configure how to quantize a model or a part of the model. It will map each layer to
an instance of SingleLayerConfig according to the settings. It provides diverse methods
to set the quantization strategies.
Args:
activation(QuanterFactory): The global quantizer used to quantize the activations.
weight(QuanterFactory): The global quantizer used to quantize the weights.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
print(q_config)
"""
def __init__(self, activation: QuanterFactory, weight: QuanterFactory):
if activation is None and weight is None:
self._global_config = None
else:
self._global_config = SingleLayerConfig(activation, weight)
self._layer2config = {}
self._prefix2config = {}
self._type2config = {}
self._model = None
self._qat_layer_mapping = copy.deepcopy(DEFAULT_QAT_LAYER_MAPPINGS)
self._customized_leaves = []
def add_layer_config(
self,
layer: Union[Layer, list],
activation: QuanterFactory = None,
weight: QuanterFactory = None,
):
r"""
Set the quantization config by layer. It has the highest priority among
all the setting methods.
Args:
layer(Union[Layer, list]): One or a list of layers.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_layer_config([model.fc], activation=quanter, weight=quanter)
print(q_config)
"""
if isinstance(layer, list):
for _element in layer:
self.add_layer_config(
_element, activation=activation, weight=weight
)
else:
self.add_name_config(
layer.full_name(), activation=activation, weight=weight
)
def add_name_config(
self,
layer_name: Union[str, list],
activation: QuanterFactory = None,
weight: QuanterFactory = None,
):
r"""
Set the quantization config by the full name of the layer. Its priority is
lower than `add_layer_config`.
Args:
layer_name(Union[str, list]): One or a list of layers' full names.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_name_config([model.fc.full_name()], activation=quanter, weight=quanter)
print(q_config)
"""
if isinstance(layer_name, str):
config = SingleLayerConfig(activation, weight)
self._prefix2config[layer_name] = config
if isinstance(layer_name, list):
for _element in layer_name:
self.add_name_config(
_element, activation=activation, weight=weight
)
def add_type_config(
self,
layer_type: Union[type, list],
activation: QuanterFactory = None,
weight: QuanterFactory = None,
):
r"""
Set the quantization config by the type of layer. The `layer_type` should be a
subclass of `paddle.nn.Layer`. Its priority is lower than `add_layer_config`
and `add_name_config`.
Args:
layer_type(Union[type, list]): One or a list of layer types. It should be a subclass of
`paddle.nn.Layer`. The Python built-in function `type()` can be used to get the type of a layer.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config([Linear], activation=quanter, weight=quanter)
print(q_config)
"""
if isinstance(layer_type, type) and issubclass(
layer_type, paddle.nn.Layer
):
config = SingleLayerConfig(activation, weight)
self._type2config[layer_type] = config
if isinstance(layer_type, list):
for _element in layer_type:
self.add_type_config(
_element, activation=activation, weight=weight
)
def add_qat_layer_mapping(self, source: type, target: type):
r"""
Add a rule converting layers to simulated quantization layers
before quantization-aware training. It will convert layers
with type `source` to layers with type `target`. `source` and
`target` should be subclasses of `paddle.nn.Layer`. A default
mapping is provided by the property `default_qat_layer_mapping`.
Args:
source(type): The type of layers that will be converted.
target(type): The type of layers to which they will be converted.
Examples:
.. code-block:: python
from paddle.nn import Conv2D
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
class CustomizedQuantedConv2D:
def forward(self, x):
pass
# add some code for quantization simulation
q_config.add_qat_layer_mapping(Conv2D, CustomizedQuantedConv2D)
"""
assert isinstance(source, type) and issubclass(
    source, paddle.nn.Layer
), "The source layer to be replaced should be a subclass of paddle.nn.Layer"
assert isinstance(target, type) and issubclass(
    target, paddle.nn.Layer
), "The target layer should be a subclass of paddle.nn.Layer"
self._qat_layer_mapping[source] = target
def add_customized_leaf(self, layer_type: type):
r"""
Declare a customized layer as a leaf of the model for quantization.
A leaf layer is quantized as a whole; the sublayers of a
leaf layer will not be quantized separately.
Args:
layer_type(type): The type of layer to be declared as leaf.
Examples:
.. code-block:: python
from paddle.nn import Sequential
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
q_config = QuantConfig(activation=None, weight=None)
q_config.add_customized_leaf(Sequential)
"""
self._customized_leaves.append(layer_type)
@property
def customized_leaves(self):
r"""
Get all the customized leaves.
"""
return self._customized_leaves
def _need_observe(self, layer: Layer):
r"""
Whether the layer should be observed by an observer.
"""
return self._is_leaf(layer) and self._has_observer_config(layer)
def _has_observer_config(self, layer: Layer):
r"""
Whether the layer has been configured for activation quantization.
"""
_config = self._get_config_by_layer(layer)
return _config is not None and _config.activation is not None
def _is_leaf(self, layer: Layer):
return (
self._is_default_leaf(layer)
or self._is_real_leaf(layer)
or self._is_customized_leaf(layer)
)
def _is_default_leaf(self, layer: Layer):
return type(layer) in DEFAULT_LEAVES
def _is_real_leaf(self, layer: Layer):
r"""
A layer is a real leaf when it has no sublayers.
"""
return layer._sub_layers is None or len(layer._sub_layers) == 0
def _is_customized_leaf(self, layer: Layer):
return type(layer) in self.customized_leaves
def _get_observer(self, layer: Layer):
r"""
Create an instance of observer or quanter according to the
given layer's quantization config.
"""
_config = self._get_config_by_layer(layer)
_observer = None if _config is None else _config.activation
return None if _observer is None else _observer._instance(layer)
@property
def qat_layer_mappings(self):
return self._qat_layer_mapping
@property
def default_qat_layer_mapping(self):
return DEFAULT_QAT_LAYER_MAPPINGS
@property
def global_config(self) -> SingleLayerConfig:
return self._global_config
def _get_config_by_layer(self, layer) -> SingleLayerConfig:
return self._layer2config.get(layer, None)
def _is_quantifiable(self, layer: Layer):
r"""
The layer is quantifiable when it is configured with an activation
quanter/observer or a weight quanter/observer.
"""
return layer in self._layer2config
def _specify(self, model: Layer):
r"""
Specify the quantization config of each sublayer in model.
For each layer in the sublayers of the model:
1. Set the config by the global config
2. Overwrite the config with the parent's config
3. Overwrite the config with config set by layer's type
4. Overwrite the config with config set by layer's full name
5. Overwrite the config with config set by layer
Args:
model(Layer): The model to be specified by the config.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear, Sequential
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Sequential(Linear(576, 120),Linear(576, 120))
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_layer_config([model.fc], activation=quanter, weight=quanter)
q_config._specify(model)
"""
self._model = model
self._specify_helper(self._model)
def _specify_helper(self, model: Layer):
for child in model.children():
layer_prefix = child.full_name()
config = self._layer2config.get(model, self.global_config)
config = self._type2config.get(type(child), config)
config = self._prefix2config.get(layer_prefix, config)
if config is not None:
self._layer2config[child] = config
self._specify_helper(child)
return self
def details(self) -> str:
r"""
Get the formatted details of the current config.
"""
return self._details_helper(self._model)
def _details_helper(self, layer: Layer):
extra_lines = []
sublayer_lines = []
for name, sublayer in layer.named_children():
sublayer_str = self._details_helper(sublayer)
sublayer_str = self._addindent(sublayer_str, 2)
sublayer_lines.append(
'('
+ name
+ '): '
+ sublayer_str
+ ', '
+ str(self._layer2config[sublayer])
)
final_str = layer.__class__.__name__ + '('
if extra_lines:
if len(extra_lines) > 1:
final_str += '\n ' + '\n '.join(extra_lines) + '\n'
elif len(extra_lines) == 1:
final_str += extra_lines[0]
if sublayer_lines:
final_str += '\n ' + '\n '.join(sublayer_lines) + '\n'
final_str += ')'
return final_str
def _addindent(self, string, indent):
s1 = string.split('\n')
if len(s1) == 1:
return string
s2 = []
for idx, line in enumerate(s1):
if idx > 0:
s2.append(str((indent * ' ') + line))
return s1[0] + '\n' + '\n'.join(s2)
def __str__(self):
result = ""
result += f"Global config:\n{self._global_config}\n"
if len(self._type2config) > 0:
result += f"Layer type config:\n{self._type2config}\n"
if len(self._prefix2config) > 0:
result += f"Layer prefix config: \n{self._prefix2config}\n"
return result
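To make the resolution order in `_specify` concrete (per-layer settings override full-name settings, which override type settings, which override the global config), here is a hedged sketch; "TinyModel" is illustrative only and not part of this PR:

# Sketch of config priority resolution performed by _specify.
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.fc1 = Linear(16, 16)
        self.fc2 = Linear(16, 4)

model = TinyModel()
type_q = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
layer_q = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)

q_config = QuantConfig(activation=None, weight=None)
# Type-level rule: applies to every Linear layer.
q_config.add_type_config([Linear], weight=type_q)
# Layer-level rule: overrides the type rule for model.fc2 only.
q_config.add_layer_config([model.fc2], weight=layer_q)
q_config._specify(model)  # fc1 -> type_q, fc2 -> layer_q
print(q_config.details())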
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import inspect
from functools import partial
from paddle.nn import Layer
from .base_quanter import BaseQuanter
class ClassWithArguments(metaclass=abc.ABCMeta):
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
@property
def args(self):
return self._args
@property
def kwargs(self):
return self._kwargs
@abc.abstractmethod
def _get_class(self):
pass
def __str__(self):
args_str = ",".join(
list(self.args) + [f"{k}={v}" for k, v in self.kwargs.items()]
)
return f"{self.__class__.__name__}({args_str})"
def __repr__(self):
return self.__str__()
class QuanterFactory(ClassWithArguments):
r"""
The factory holds the quanter's class information and
the arguments used to create a quanter instance.
"""
def __init__(self, *args, **kwargs):
super(QuanterFactory, self).__init__(*args, **kwargs)
self.partial_class = None
def _instance(self, layer: Layer) -> BaseQuanter:
r"""
Create an instance of the quanter for the target layer.
"""
if self.partial_class is None:
self.partial_class = partial(
self._get_class(), *self.args, **self.kwargs
)
return self.partial_class(layer)
def quanter(class_name):
r"""
Decorator to declare a factory class for a quanter.
Args:
class_name(str): The name of the factory class to be declared.
Examples:
.. code-block:: python
# Given codes in ./customized_quanter.py
from paddle.quantization import quanter
from paddle.quantization import BaseQuanter
@quanter("CustomizedQuanter")
class CustomizedQuanterLayer(BaseQuanter):
def __init__(self, arg1, kwarg1=None):
pass
# Used in ./test.py
# from .customized_quanter import CustomizedQuanter
from paddle.quantization import QuantConfig
arg1_value = "test"
kwarg1_value = 20
quanter = CustomizedQuanter(arg1_value, kwarg1=kwarg1_value)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def wrapper(target_class):
init_function_str = f"""
def init_function(self, *args, **kwargs):
super(type(self), self).__init__(*args, **kwargs)
import importlib
module = importlib.import_module("{target_class.__module__}")
my_class = getattr(module, "{target_class.__name__}")
globals()["{target_class.__name__}"] = my_class
def get_class_function(self):
return {target_class.__name__}
locals()["init_function"]=init_function
locals()["get_class_function"]=get_class_function
"""
exec(init_function_str)
frm = inspect.stack()[1]
mod = inspect.getmodule(frm[0])
new_class = type(
class_name,
(QuanterFactory,),
{
"__init__": locals()["init_function"],
"_get_class": locals()["get_class_function"],
},
)
setattr(mod, class_name, new_class)
if "__all__" in mod.__dict__:
mod.__all__.append(class_name)
return target_class
return wrapper
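Functionally, the `quanter` decorator above is roughly equivalent to hand-writing a `QuanterFactory` subclass. A hedged sketch ("MyQuanter" is a hypothetical name; the built-in layer class is reused only to keep the sketch runnable):

# Roughly what @quanter("MyQuanter") generates, written by hand.
from paddle.quantization.factory import QuanterFactory
from paddle.quantization.quanters.abs_max import (
    FakeQuanterWithAbsMaxObserverLayer,
)

class MyQuanter(QuanterFactory):
    def __init__(self, *args, **kwargs):
        super(MyQuanter, self).__init__(*args, **kwargs)

    def _get_class(self):
        # The concrete BaseQuanter subclass instantiated per layer.
        return FakeQuanterWithAbsMaxObserverLayer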
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .abs_max import FakeQuanterWithAbsMaxObserver
__all__ = ["FakeQuanterWithAbsMaxObserver"]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _legacy_C_ops
from paddle.fluid.framework import _varbase_creator
from paddle.framework import ParamAttr
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from ..base_quanter import BaseQuanter
from ..factory import QuanterFactory
class FakeQuanterWithAbsMaxObserver(QuanterFactory):
r"""
Compute quantization parameters and simulate quantization.
It collects the maximum absolute value of the target tensor with a moving average.
The averaged value is used as the quantization scale to quantize and
dequantize the tensor.
It is symmetric uniform quantization, which means the zero point is always 0.
The moving average is computed as follows:
.. math::
state = rate * state + 1
accum = rate * accum + max(abs(x))
scale = accum / state
Where:
- :math:`x` is the input tensor.
- :math:`state` and :math:`accum` are zero-initialized accumulators.
- :math:`rate` is the moving average rate.
- :math:`scale` is the quantization scale.
And simulated quantization is computed as:
.. math::
range = 2^{bit\_length - 1} - 1
out = round(x / scale * range) * scale / range
Where:
- :math:`{bit\_length}` is the length of bits.
- :math:`x` is the input tensor and :math:`out` is the output of simulated quantization.
Args:
moving_rate(float, optional): The rate of the moving average.
bit_length(int, optional): The number of bits used to represent a quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(
self,
moving_rate=0.9,
bit_length=8,
dtype='float32',
name=None,
):
super(FakeQuanterWithAbsMaxObserver, self).__init__(
name=name,
moving_rate=moving_rate,
bit_length=bit_length,
dtype=dtype,
)
def _get_class(self):
return FakeQuanterWithAbsMaxObserverLayer
class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter):
def __init__(
self,
layer,
name=None,
moving_rate=0.9,
bit_length=8,
dtype='float32',
):
super(FakeQuanterWithAbsMaxObserverLayer, self).__init__()
self._moving_rate = moving_rate
self._bit_length = bit_length
scale_prefix = (
"{}.scale".format(name) if name else 'quant_dequant.scale'
)
scale_attr = ParamAttr(
name=unique_name.generate(scale_prefix),
initializer=Constant(0.001),
trainable=False,
)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype
)
self._scale.stop_gradient = True
state_prefix = (
"{}.state".format(name) if name else 'quant_dequant.state'
)
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False,
)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=dtype
)
self._state.stop_gradient = True
accum_prefix = (
"{}.accum".format(name) if name else 'quant_dequant.accum'
)
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False,
)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype
)
self._accum.stop_gradient = True
def forward(self, input):
attrs = (
'moving_rate',
self._moving_rate,
'bit_length',
self._bit_length,
'is_test',
not self.training,
)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False,
)
state = self._state if self.training else None
accum = self._accum if self.training else None
(
out,
_,
_,
_,
) = _legacy_C_ops.fake_quantize_dequantize_moving_average_abs_max(
input,
self._scale,
accum,
state,
quant_out,
self._scale,
state,
accum,
*attrs
)
return out
def bit_length(self):
    return self._bit_length
def quant_axis(self):
return None
def scales(self):
return self._scale
def zero_points(self):
return None
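The formulas in the `FakeQuanterWithAbsMaxObserver` docstring can be reproduced with a small NumPy sketch. This mirrors the math only, not the exact `fake_quantize_dequantize_moving_average_abs_max` kernel:

import numpy as np

def fake_quant_moving_absmax(x, state, accum, rate=0.9, bit_length=8):
    # Moving-average scale: state = rate*state + 1; accum = rate*accum + max|x|
    state = rate * state + 1
    accum = rate * accum + np.abs(x).max()
    scale = accum / state
    # Symmetric uniform quant-dequant: out = round(x/scale*r) * scale/r
    r = 2 ** (bit_length - 1) - 1
    out = np.round(x / scale * r) * scale / r
    return out, state, accum

x = np.random.randn(4, 4).astype("float32")
out, state, accum = fake_quant_moving_absmax(x, state=1.0, accum=1.0)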
add_subdirectory(quantization)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
......
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Iterable, Union
import numpy as np
import paddle
from paddle.nn import Linear
from paddle.quantization.base_quanter import BaseQuanter
from paddle.quantization.factory import quanter
linear_quant_axis = 1
@quanter("CustomizedQuanter")
class CustomizedQuanterLayer(BaseQuanter):
def __init__(self, layer, bit_length=8, kwargs1=None):
super(CustomizedQuanterLayer, self).__init__()
self._layer = layer
self._bit_length = bit_length
self._kwargs1 = kwargs1
def scales(self) -> Union[paddle.Tensor, np.ndarray]:
return None
def bit_length(self):
return self._bit_length
def quant_axis(self) -> Union[int, Iterable]:
return linear_quant_axis if isinstance(self._layer, Linear) else None
def zero_points(self) -> Union[paddle.Tensor, np.ndarray]:
return None
def forward(self, input):
return input
class TestCustomizedQuanter(unittest.TestCase):
def test_details(self):
layer = Linear(5, 5)
bit_length = 4
quanter = CustomizedQuanter( # noqa: F821
bit_length=bit_length, kwargs1="test"
)
quanter = quanter._instance(layer)
self.assertEqual(quanter.bit_length(), bit_length)
self.assertEqual(quanter.quant_axis(), linear_quant_axis)
self.assertEqual(quanter._kwargs1, 'test')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, Linear, ReLU, Sequential
from paddle.quantization import QuantConfig
from paddle.quantization.base_quanter import BaseQuanter
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class LeNetDygraph(paddle.nn.Layer):
def __init__(self, num_classes=10):
super(LeNetDygraph, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2D(3, 6, 3, stride=1, padding=1),
ReLU(),
paddle.nn.MaxPool2D(2, 2),
Conv2D(6, 16, 5, stride=1, padding=0),
ReLU(),
paddle.nn.MaxPool2D(2, 2),
)
if num_classes > 0:
self.fc = Sequential(
Linear(576, 120), Linear(120, 84), Linear(84, 10)
)
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
out = F.relu(x)
out = F.relu(out)
return out
class TestQuantConfig(unittest.TestCase):
def setUp(self):
self.model = LeNetDygraph()
self.quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
def test_global_config(self):
self.q_config = QuantConfig(
activation=self.quanter, weight=self.quanter
)
self.q_config._specify(self.model)
self.assertIsNotNone(self.q_config.global_config.activation)
self.assertIsNotNone(self.q_config.global_config.weight)
for layer in self.model.sublayers():
config = self.q_config._get_config_by_layer(layer)
self.assertTrue(config.activation == self.quanter)
self.assertTrue(config.weight == self.quanter)
def assert_just_linear_weight_configure(self, model, config):
for layer in model.sublayers():
layer_config = config._get_config_by_layer(layer)
if type(layer) == Linear:
self.assertIsNone(layer_config.activation)
self.assertEqual(layer_config.weight, self.quanter)
self.assertTrue(config._is_quantifiable(layer))
elif type(layer) == Conv2D:
self.assertIsNone(layer_config)
self.assertFalse(config._is_quantifiable(layer))
def test_add_layer_config(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_layer_config(
[self.model.fc], activation=None, weight=self.quanter
)
self.q_config._specify(self.model)
self.assert_just_linear_weight_configure(self.model, self.q_config)
def test_add_name_config(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_name_config(
[self.model.fc.full_name()], activation=None, weight=self.quanter
)
self.q_config._specify(self.model)
self.assert_just_linear_weight_configure(self.model, self.q_config)
def test_add_type_config(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config(
[Linear], activation=None, weight=self.quanter
)
self.q_config._specify(self.model)
self.assert_just_linear_weight_configure(self.model, self.q_config)
def test_add_qat_layer_mapping(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_qat_layer_mapping(Sequential, Conv2D)
self.assertTrue(Sequential in self.q_config.qat_layer_mappings)
self.assertTrue(
Sequential not in self.q_config.default_qat_layer_mapping
)
def test_add_customized_leaf(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_customized_leaf(Sequential)
self.assertTrue(Sequential in self.q_config.customized_leaves)
self.assertTrue(self.q_config._is_customized_leaf(self.model.fc))
self.assertTrue(self.q_config._is_leaf(self.model.fc))
self.assertFalse(self.q_config._is_default_leaf(self.model.fc))
self.assertFalse(self.q_config._is_real_leaf(self.model.fc))
def test_need_observe(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_layer_config(
[self.model.fc], activation=self.quanter, weight=self.quanter
)
self.q_config.add_customized_leaf(Sequential)
self.q_config._specify(self.model)
self.assertTrue(self.q_config._has_observer_config(self.model.fc))
self.assertTrue(self.q_config._need_observe(self.model.fc))
def test__get_observer(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_layer_config(
[self.model.fc], activation=self.quanter, weight=self.quanter
)
self.q_config._specify(self.model)
observer = self.q_config._get_observer(self.model.fc)
self.assertIsInstance(observer, BaseQuanter)
def test_details(self):
self.q_config = QuantConfig(
activation=self.quanter, weight=self.quanter
)
self.q_config._specify(self.model)
self.assertIsNotNone(self.q_config.details())
self.assertIsNotNone(self.q_config.__str__())
if __name__ == '__main__':
unittest.main()
......@@ -386,6 +386,8 @@ packages=['paddle',
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
'paddle.incubate.distributed.models.moe.gate',
+'paddle.quantization',
+'paddle.quantization.quanters',
'paddle.sparse',
'paddle.sparse.nn',
'paddle.sparse.nn.layer',
......
......@@ -1254,6 +1254,8 @@ def get_setup_parameters():
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
'paddle.incubate.distributed.models.moe.gate',
+'paddle.quantization',
+'paddle.quantization.quanters',
'paddle.sparse',
'paddle.sparse.nn',
'paddle.sparse.nn.layer',
......