Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
941444aa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
941444aa
编写于
12月 23, 2022
作者:
W
whs
提交者:
GitHub
12月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add configure of quantization for dynamic graph (#48000)
上级
ae544586
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
1107 addition
and
11 deletion
+1107
-11
python/paddle/__init__.py
python/paddle/__init__.py
+1
-0
python/paddle/quantization/__init__.py
python/paddle/quantization/__init__.py
+22
-11
python/paddle/quantization/base_quanter.py
python/paddle/quantization/base_quanter.py
+66
-0
python/paddle/quantization/config.py
python/paddle/quantization/config.py
+444
-0
python/paddle/quantization/factory.py
python/paddle/quantization/factory.py
+131
-0
python/paddle/quantization/quanters/__init__.py
python/paddle/quantization/quanters/__init__.py
+17
-0
python/paddle/quantization/quanters/abs_max.py
python/paddle/quantization/quanters/abs_max.py
+194
-0
python/paddle/tests/CMakeLists.txt
python/paddle/tests/CMakeLists.txt
+2
-0
python/paddle/tests/quantization/CMakeLists.txt
python/paddle/tests/quantization/CMakeLists.txt
+9
-0
python/paddle/tests/quantization/test_customized_quanter.py
python/paddle/tests/quantization/test_customized_quanter.py
+66
-0
python/paddle/tests/quantization/test_quant.py
python/paddle/tests/quantization/test_quant.py
+151
-0
python/setup.py.in
python/setup.py.in
+2
-0
setup.py
setup.py
+2
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
941444aa
...
...
@@ -85,6 +85,7 @@ import paddle.vision # noqa: F401
import
paddle.audio
# noqa: F401
import
paddle.geometric
# noqa: F401
import
paddle.sparse
# noqa: F401
import
paddle.quantization
# noqa: F401
from
.tensor.attribute
import
is_complex
# noqa: F401
from
.tensor.attribute
import
is_integer
# noqa: F401
...
...
python/paddle/quantization/__init__.py
浏览文件 @
941444aa
...
...
@@ -12,35 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_config
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_config
import
(
PTQConfig
,
default_ptq_config
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
BaseQuantizer
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
AbsmaxQuantizer
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
PerChannelAbsmaxQuantizer
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
KLQuantizer
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
HistQuantizer
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
SUPPORT_ACT_QUANTIZERS
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_quantizer
import
(
SUPPORT_WT_QUANTIZERS
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq_registry
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq_registry
import
(
PTQRegistry
,
)
from
..
.
fluid.contrib.slim.quantization.imperative.ptq
import
ImperativePTQ
from
..
.
fluid.contrib.slim.quantization.imperative.qat
import
(
from
..fluid.contrib.slim.quantization.imperative.ptq
import
ImperativePTQ
from
..fluid.contrib.slim.quantization.imperative.qat
import
(
ImperativeQuantAware
,
)
from
.config
import
QuantConfig
from
.base_quanter
import
BaseQuanter
from
.factory
import
quanter
__all__
=
[
"QuantConfig"
,
"BaseQuanter"
,
"quanter"
,
]
python/paddle/quantization/base_quanter.py
0 → 100644
浏览文件 @
941444aa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
abc
from
collections.abc
import
Iterable
from
typing
import
Union
import
numpy
as
np
import
paddle
from
paddle.nn
import
Layer
class
BaseQuanter
(
Layer
,
metaclass
=
abc
.
ABCMeta
):
r
"""
Built-in quanters and customized quanters should extend this base quanter
and implement abstract methods.
"""
def
__init__
(
self
):
super
(
BaseQuanter
,
self
).
__init__
()
@
abc
.
abstractmethod
def
forward
(
self
,
input
):
pass
@
abc
.
abstractmethod
def
scales
(
self
)
->
Union
[
paddle
.
Tensor
,
np
.
ndarray
]:
r
"""
Get the scales used for quantization.
It can be none which meams the quanter didn't hold scales for quantization.
"""
pass
@
abc
.
abstractmethod
def
zero_points
(
self
)
->
Union
[
paddle
.
Tensor
,
np
.
ndarray
]:
r
"""
Get the zero points used for quantization.
It can be none which meams the quanter didn't hold zero points for quantization.
"""
pass
@
abc
.
abstractmethod
def
quant_axis
(
self
)
->
Union
[
int
,
Iterable
]:
r
"""
Get the axis of quantization. None means tensor-wise quantization.
"""
pass
@
abc
.
abstractmethod
def
bit_length
(
self
):
r
"""
Get the bit length of quantization.
"""
pass
python/paddle/quantization/config.py
0 → 100644
浏览文件 @
941444aa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
from
typing
import
Dict
,
Union
import
paddle
import
paddle.nn
as
nn
from
paddle.nn
import
Layer
from
.factory
import
QuanterFactory
# TODO: Implement quanted layer and fill the mapping dict
DEFAULT_QAT_LAYER_MAPPINGS
:
Dict
[
Layer
,
Layer
]
=
{}
DEFAULT_LEAVES
=
[
nn
.
ReLU
,
nn
.
AvgPool2D
]
class
SingleLayerConfig
(
object
):
r
"""
Configure how to quantize the activations and weights of a single layer.
Args:
activation(QuanterFactory): The factory to create instance of quanter used to quantize activations.
weight(QuanterFactory): The factory to create instance of quanter used to quantize weights.
"""
def
__init__
(
self
,
activation
:
QuanterFactory
,
weight
:
QuanterFactory
):
self
.
_activation
=
activation
self
.
_weight
=
weight
@
property
def
activation
(
self
):
return
self
.
_activation
@
property
def
weight
(
self
):
return
self
.
_weight
def
__str__
(
self
):
return
f
"activation:
{
self
.
_activation
}
\n
weight:
{
self
.
_weight
}
"
class
QuantConfig
(
object
):
r
"""
Configure how to quantize a model or a part of the model. It will map each layer to
an instance of SingleLayerConfig by the settings. It provides diverse methods to set
the strategies of quantization.
Args:
activation(QuanterFactory): The global quantizer used to quantize the activations.
weight(QuanterFactory): The global quantizer used to quantize the weights.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)
print(q_config)
"""
def
__init__
(
self
,
activation
:
QuanterFactory
,
weight
:
QuanterFactory
):
if
activation
is
None
and
weight
is
None
:
self
.
_global_config
=
None
else
:
self
.
_global_config
=
SingleLayerConfig
(
activation
,
weight
)
self
.
_layer2config
=
{}
self
.
_prefix2config
=
{}
self
.
_type2config
=
{}
self
.
_model
=
None
self
.
_qat_layer_mapping
=
copy
.
deepcopy
(
DEFAULT_QAT_LAYER_MAPPINGS
)
self
.
_customized_leaves
=
[]
def
add_layer_config
(
self
,
layer
:
Union
[
Layer
,
list
],
activation
:
QuanterFactory
=
None
,
weight
:
QuanterFactory
=
None
,
):
r
"""
Set the quantization config by layer. It has the highest priority among
all the setting methods.
Args:
layer(Union[Layer, list]): One or a list of layers.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_layer_config([model.fc], activation=quanter, weight=quanter)
print(q_config)
"""
if
isinstance
(
layer
,
list
):
for
_element
in
layer
:
self
.
add_layer_config
(
_element
,
activation
=
activation
,
weight
=
weight
)
else
:
self
.
add_name_config
(
layer
.
full_name
(),
activation
=
activation
,
weight
=
weight
)
def
add_name_config
(
self
,
layer_name
:
Union
[
str
,
list
],
activation
:
QuanterFactory
=
None
,
weight
:
QuanterFactory
=
None
,
):
r
"""
Set the quantization config by full name of layer. Its priority is
lower than `add_layer_config`.
Args:
layer_name(Union[str, list]): One or a list of layers' full name.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_name_config([model.fc.full_name()], activation=quanter, weight=quanter)
print(q_config)
"""
if
isinstance
(
layer_name
,
str
):
config
=
SingleLayerConfig
(
activation
,
weight
)
self
.
_prefix2config
[
layer_name
]
=
config
if
isinstance
(
layer_name
,
list
):
for
_element
in
layer_name
:
self
.
add_name_config
(
_element
,
activation
=
activation
,
weight
=
weight
)
def
add_type_config
(
self
,
layer_type
:
Union
[
type
,
list
],
activation
:
QuanterFactory
=
None
,
weight
:
QuanterFactory
=
None
,
):
r
"""
Set the quantization config by the type of layer. The `layer_type` should be
subclass of `paddle.nn.Layer`. Its priority is lower than `add_layer_config`
and `add_name_config`.
Args:
layer_type(Union[type, list]): One or a list of layers' type. It should be subclass of
`paddle.nn.Layer`. Python build-in function `type()` can be used to get the type of a layer.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Linear(576, 120)
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config([Linear], activation=quanter, weight=quanter)
print(q_config)
"""
if
isinstance
(
layer_type
,
type
)
and
issubclass
(
layer_type
,
paddle
.
nn
.
Layer
):
config
=
SingleLayerConfig
(
activation
,
weight
)
self
.
_type2config
[
layer_type
]
=
config
if
isinstance
(
layer_type
,
list
):
for
_element
in
layer_type
:
self
.
add_type_config
(
_element
,
activation
=
activation
,
weight
=
weight
)
def
add_qat_layer_mapping
(
self
,
source
:
type
,
target
:
type
):
r
"""
Add rules converting layers to simulated quantization layers
before quantization-aware training. It will convert layers
with type `source` to layers with type `target`. `source` and
`target` should be subclass of `paddle.nn.Layer`. And a default
mapping is provided by property `default_qat_layer_mapping`.
Args:
source(type): The type of layers that will be converted.
target(type): The type of layers that will be converted to.
Examples:
.. code-block:: python
from paddle.nn import Conv2D
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
class CustomizedQuantedConv2D:
def forward(self, x):
pass
# add some code for quantization simulation
q_config.add_qat_layer_mapping(Conv2D, CustomizedQuantedConv2D)
"""
assert
isinstance
(
source
,
type
)
and
issubclass
(
source
,
paddle
.
nn
.
Layer
),
"The source layer to be placed should be a subclass of paddle.nn.Layer"
assert
isinstance
(
target
,
type
)
and
issubclass
(
source
,
paddle
.
nn
.
Layer
),
"The target layer should be a subclass of paddle.nn.qat.Layer"
self
.
_qat_layer_mapping
[
source
]
=
target
def
add_customized_leaf
(
self
,
layer_type
:
type
):
r
"""
Declare the customized layer as leaf of model for quantization.
The leaf layer is quantized as one layer. The sublayers of
leaf layer will not be quantized.
Args:
layer_type(type): The type of layer to be declared as leaf.
Examples:
.. code-block:: python
from paddle.nn import Sequential
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
q_config = QuantConfig(activation=None, weight=None)
q_config.add_customized_leaf(Sequential)
"""
self
.
_customized_leaves
.
append
(
layer_type
)
@
property
def
customized_leaves
(
self
):
r
"""
Get all the customized leaves.
"""
return
self
.
_customized_leaves
def
_need_observe
(
self
,
layer
:
Layer
):
r
"""
Whether the layer should be observed by observer.
"""
return
self
.
_is_leaf
(
layer
)
and
self
.
_has_observer_config
(
layer
)
def
_has_observer_config
(
self
,
layer
:
Layer
):
r
"""
Whether the layer has been configured for activation quantization.
"""
_config
=
self
.
_get_config_by_layer
(
layer
)
return
_config
is
not
None
and
_config
.
activation
is
not
None
def
_is_leaf
(
self
,
layer
:
Layer
):
return
(
self
.
_is_default_leaf
(
layer
)
or
self
.
_is_real_leaf
(
layer
)
or
self
.
_is_customized_leaf
(
layer
)
)
def
_is_default_leaf
(
self
,
layer
:
Layer
):
return
type
(
layer
)
in
DEFAULT_LEAVES
def
_is_real_leaf
(
self
,
layer
:
Layer
):
r
"""
The leaf is real leaf when it has no sublayers.
"""
return
layer
.
_sub_layers
is
None
or
len
(
layer
.
_sub_layers
)
==
0
def
_is_customized_leaf
(
self
,
layer
:
Layer
):
return
type
(
layer
)
in
self
.
customized_leaves
def
_get_observer
(
self
,
layer
:
Layer
):
r
"""
Create an instance of observer or quanter according to the
given layer's quantization config.
"""
_config
=
self
.
_get_config_by_layer
(
layer
)
_observer
=
None
if
_config
is
None
else
_config
.
activation
return
None
if
_observer
is
None
else
_observer
.
_instance
(
layer
)
@
property
def
qat_layer_mappings
(
self
):
return
self
.
_qat_layer_mapping
@
property
def
default_qat_layer_mapping
(
self
):
return
DEFAULT_QAT_LAYER_MAPPINGS
@
property
def
global_config
(
self
)
->
SingleLayerConfig
:
return
self
.
_global_config
def
_get_config_by_layer
(
self
,
layer
)
->
SingleLayerConfig
:
return
self
.
_layer2config
.
get
(
layer
,
None
)
def
_is_quantifiable
(
self
,
layer
:
Layer
):
r
"""
The layer is quantifiable when it configured by activation quanter/observer
or weight quanter/observer.
"""
return
layer
in
self
.
_layer2config
def
_specify
(
self
,
model
:
Layer
):
r
"""
Specify the quantization config of each sublayer in model.
For each layer in sublayers of mode,
1. Set the config by global config
2. Overwrite the config with parents' config
3. Overwrite the config with config set by layer's type
4. Overwrite the config with config set by layer's full name
5. Overwrite the config with config set by layer
Args:
model(Layer): The model to be specified by the config.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Linear, Sequential
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.fc = Sequential(Linear(576, 120),Linear(576, 120))
model = Model()
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)
q_config.add_layer_config([model.fc], activation=quanter, weight=quanter)
q_config._specify(model)
"""
self
.
_model
=
model
self
.
_specify_helper
(
self
.
_model
)
def
_specify_helper
(
self
,
model
:
Layer
):
for
child
in
model
.
children
():
layer_prefix
=
child
.
full_name
()
config
=
self
.
_layer2config
.
get
(
model
,
self
.
global_config
)
config
=
self
.
_type2config
.
get
(
type
(
child
),
config
)
config
=
self
.
_prefix2config
.
get
(
layer_prefix
,
config
)
if
config
is
not
None
:
self
.
_layer2config
[
child
]
=
config
self
.
_specify_helper
(
child
)
return
self
def
details
(
self
)
->
str
:
r
"""
Get the formated details of current config.
"""
return
self
.
_details_helper
(
self
.
_model
)
def
_details_helper
(
self
,
layer
:
Layer
):
extra_lines
=
[]
sublayer_lines
=
[]
for
name
,
sublayer
in
layer
.
named_children
():
sublayer_str
=
self
.
_details_helper
(
sublayer
)
sublayer_str
=
self
.
_addindent
(
sublayer_str
,
2
)
sublayer_lines
.
append
(
'('
+
name
+
'): '
+
sublayer_str
+
', '
+
str
(
self
.
_layer2config
[
sublayer
])
)
final_str
=
layer
.
__class__
.
__name__
+
'('
if
extra_lines
:
if
len
(
extra_lines
)
>
1
:
final_str
+=
'
\n
'
+
'
\n
'
.
join
(
extra_lines
)
+
'
\n
'
elif
len
(
extra_lines
)
==
1
:
final_str
+=
extra_lines
[
0
]
if
sublayer_lines
:
final_str
+=
'
\n
'
+
'
\n
'
.
join
(
sublayer_lines
)
+
'
\n
'
final_str
+=
')'
return
final_str
def
_addindent
(
self
,
string
,
indent
):
s1
=
string
.
split
(
'
\n
'
)
if
len
(
s1
)
==
1
:
return
string
s2
=
[]
for
idx
,
line
in
enumerate
(
s1
):
if
idx
>
0
:
s2
.
append
(
str
((
indent
*
' '
)
+
line
))
return
s1
[
0
]
+
'
\n
'
+
'
\n
'
.
join
(
s2
)
def
__str__
(
self
):
result
=
""
result
+=
f
"Global config:
\n
{
self
.
_global_config
}
\n
"
if
len
(
self
.
_type2config
)
>
0
:
result
+=
f
"Layer type config:
\n
{
self
.
_type2config
}
\n
"
if
len
(
self
.
_prefix2config
)
>
0
:
result
+=
f
"Layer prefix config:
\n
{
self
.
_prefix2config
}
\n
"
return
result
python/paddle/quantization/factory.py
0 → 100644
浏览文件 @
941444aa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
abc
import
inspect
from
functools
import
partial
from
paddle.nn
import
Layer
from
.base_quanter
import
BaseQuanter
class
ClassWithArguments
(
metaclass
=
abc
.
ABCMeta
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_args
=
args
self
.
_kwargs
=
kwargs
@
property
def
args
(
self
):
return
self
.
_args
@
property
def
kwargs
(
self
):
return
self
.
_kwargs
@
abc
.
abstractmethod
def
_get_class
(
self
):
pass
def
__str__
(
self
):
args_str
=
","
.
join
(
list
(
self
.
args
)
+
[
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
self
.
kwargs
.
items
()]
)
return
f
"
{
self
.
__class__
.
__name__
}
(
{
args_str
}
)"
def
__repr__
(
self
):
return
self
.
__str__
()
class
QuanterFactory
(
ClassWithArguments
):
r
"""
The factory holds the quanter's class information and
the arguments used to create quanter instance.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
QuanterFactory
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
partial_class
=
None
def
_instance
(
self
,
layer
:
Layer
)
->
BaseQuanter
:
r
"""
Create an instance of quanter for target layer.
"""
if
self
.
partial_class
is
None
:
self
.
partial_class
=
partial
(
self
.
_get_class
(),
*
self
.
args
,
**
self
.
kwargs
)
return
self
.
partial_class
(
layer
)
def
quanter
(
class_name
):
r
"""
Annotation to declare a factory class for quanter.
Args:
class_name (str) - The name of factory class to be declared.
Examples:
.. code-block:: python
# Given codes in ./customized_quanter.py
from paddle.quantization import quanter
from paddle.quantization import BaseQuanter
@quanter("CustomizedQuanter")
class CustomizedQuanterLayer(BaseQuanter):
def __init__(self, arg1, kwarg1=None):
pass
# Used in ./test.py
# from .customized_quanter import CustomizedQuanter
from paddle.quantization import QuantConfig
arg1_value = "test"
kwarg1_value = 20
quanter = CustomizedQuanter(arg1_value, kwarg1=kwarg1_value)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def
wrapper
(
target_class
):
init_function_str
=
f
"""
def init_function(self, *args, **kwargs):
super(type(self), self).__init__(*args, **kwargs)
import importlib
module = importlib.import_module("
{
target_class
.
__module__
}
")
my_class = getattr(module, "
{
target_class
.
__name__
}
")
globals()["
{
target_class
.
__name__
}
"] = my_class
def get_class_function(self):
return
{
target_class
.
__name__
}
locals()["init_function"]=init_function
locals()["get_class_function"]=get_class_function
"""
exec
(
init_function_str
)
frm
=
inspect
.
stack
()[
1
]
mod
=
inspect
.
getmodule
(
frm
[
0
])
new_class
=
type
(
class_name
,
(
QuanterFactory
,),
{
"__init__"
:
locals
()[
"init_function"
],
"_get_class"
:
locals
()[
"get_class_function"
],
},
)
setattr
(
mod
,
class_name
,
new_class
)
if
"__all__"
in
mod
.
__dict__
:
mod
.
__all__
.
append
(
class_name
)
return
target_class
return
wrapper
python/paddle/quantization/quanters/__init__.py
0 → 100644
浏览文件 @
941444aa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.abs_max
import
FakeQuanterWithAbsMaxObserver
__all__
=
[
"FakeQuanterWithAbsMaxObserver"
]
python/paddle/quantization/quanters/abs_max.py
0 → 100644
浏览文件 @
941444aa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle
import
_legacy_C_ops
from
paddle.fluid.framework
import
_varbase_creator
from
paddle.framework
import
ParamAttr
from
paddle.nn.initializer
import
Constant
from
paddle.utils
import
unique_name
from
..base_quanter
import
BaseQuanter
from
..factory
import
QuanterFactory
class
FakeQuanterWithAbsMaxObserver
(
QuanterFactory
):
r
"""
Compute quantization parameters and simulate quantization.
It collects maximum absolute values of target tensor with moving average.
The average value will be used as quantization scale to quantize and
dequantize the tensor.
And it is symmetric uniform quantization which means the zero point is always 0.
The computational formula of moving average is described as below:
.. math::
state = rate * state + 1
accum = rate * accum + max(abs(x))
scale = accum / state
Where:
- :math:`x` is the input tensor.
- :math:`state` and :math:`accum` are zero-initialized accumulators.
- :math:`rate` is moving average rate.
- :math:`scale` is quantization scale
And the computational formula of simulate quantization is:
.. math::
range = 2^{bit\_length - 1} - 1
out = round(x / scale * range) * scale / range
Where:
- :math:`{bit\_length}` is the length of bits.
- :math:`x` is the input tensor and :math:`out` is the output of simulate quantization.
Args:
moving_rate(float, optional): The rate of moving average.
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def
__init__
(
self
,
moving_rate
=
0.9
,
bit_length
=
8
,
dtype
=
'float32'
,
name
=
None
,
):
super
(
FakeQuanterWithAbsMaxObserver
,
self
).
__init__
(
name
=
name
,
moving_rate
=
moving_rate
,
bit_length
=
bit_length
,
dtype
=
dtype
,
)
def
_get_class
(
self
):
return
FakeQuanterWithAbsMaxObserverLayer
class
FakeQuanterWithAbsMaxObserverLayer
(
BaseQuanter
):
def
__init__
(
self
,
layer
,
name
=
None
,
moving_rate
=
0.9
,
bit_length
=
8
,
dtype
=
'float32'
,
):
super
(
FakeQuanterWithAbsMaxObserverLayer
,
self
).
__init__
()
self
.
_moving_rate
=
moving_rate
self
.
_bit_length
=
bit_length
scale_prefix
=
(
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
)
scale_attr
=
ParamAttr
(
name
=
unique_name
.
generate
(
scale_prefix
),
initializer
=
Constant
(
0.001
),
trainable
=
False
,
)
self
.
_scale
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
scale_attr
,
dtype
=
dtype
)
self
.
_scale
.
stop_gradient
=
True
state_prefix
=
(
"{}.state"
.
format
(
name
)
if
name
else
'quant_dequant.state'
)
state_attr
=
ParamAttr
(
name
=
unique_name
.
generate
(
state_prefix
),
initializer
=
Constant
(
1
),
trainable
=
False
,
)
self
.
_state
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
state_attr
,
dtype
=
dtype
)
self
.
_state
.
stop_gradient
=
True
accum_prefix
=
(
"{}.accum"
.
format
(
name
)
if
name
else
'quant_dequant.accum'
)
accum_attr
=
ParamAttr
(
name
=
unique_name
.
generate
(
accum_prefix
),
initializer
=
Constant
(
1
),
trainable
=
False
,
)
self
.
_accum
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
accum_attr
,
dtype
=
dtype
)
self
.
_accum
.
stop_gradient
=
True
def
forward
(
self
,
input
):
attrs
=
(
'moving_rate'
,
self
.
_moving_rate
,
'bit_length'
,
self
.
_bit_length
,
'is_test'
,
not
self
.
training
,
)
quant_out
=
_varbase_creator
(
type
=
input
.
type
,
name
=
"{}.quantized.dequantized"
.
format
(
input
.
name
),
shape
=
input
.
shape
,
dtype
=
input
.
dtype
,
persistable
=
False
,
)
state
=
self
.
_state
if
self
.
training
else
None
accum
=
self
.
_accum
if
self
.
training
else
None
(
out
,
_
,
_
,
_
,
)
=
_legacy_C_ops
.
fake_quantize_dequantize_moving_average_abs_max
(
input
,
self
.
_scale
,
accum
,
state
,
quant_out
,
self
.
_scale
,
state
,
accum
,
*
attrs
)
return
out
def
bit_length
(
self
):
return
self
.
bits
def
quant_axis
(
self
):
return
None
def
scales
(
self
):
return
self
.
_scale
def
zero_points
(
self
):
return
None
python/paddle/tests/CMakeLists.txt
浏览文件 @
941444aa
add_subdirectory
(
quantization
)
file
(
GLOB TEST_OPS
RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
...
...
python/paddle/tests/quantization/CMakeLists.txt
0 → 100644
浏览文件 @
941444aa
file
(
GLOB TEST_OPS
RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
foreach
(
src
${
TEST_OPS
}
)
py_test
(
${
src
}
SRCS
${
src
}
.py
)
endforeach
()
python/paddle/tests/quantization/test_customized_quanter.py
0 → 100644
浏览文件 @
941444aa
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
typing
import
Iterable
,
Union
import
numpy
as
np
import
paddle
from
paddle.nn
import
Linear
from
paddle.quantization.base_quanter
import
BaseQuanter
from
paddle.quantization.factory
import
quanter
linear_quant_axis
=
1
@
quanter
(
"CustomizedQuanter"
)
class
CustomizedQuanterLayer
(
BaseQuanter
):
def
__init__
(
self
,
layer
,
bit_length
=
8
,
kwargs1
=
None
):
super
(
CustomizedQuanterLayer
,
self
).
__init__
()
self
.
_layer
=
layer
self
.
_bit_length
=
bit_length
self
.
_kwargs1
=
kwargs1
def
scales
(
self
)
->
Union
[
paddle
.
Tensor
,
np
.
ndarray
]:
return
None
def
bit_length
(
self
):
return
self
.
_bit_length
def
quant_axis
(
self
)
->
Union
[
int
,
Iterable
]:
return
linear_quant_axis
if
isinstance
(
self
.
_layer
,
Linear
)
else
None
def
zero_points
(
self
)
->
Union
[
paddle
.
Tensor
,
np
.
ndarray
]:
return
None
def
forward
(
self
,
input
):
return
input
class
TestCustomizedQuanter
(
unittest
.
TestCase
):
def
test_details
(
self
):
layer
=
Linear
(
5
,
5
)
bit_length
=
4
quanter
=
CustomizedQuanter
(
# noqa: F821
bit_length
=
bit_length
,
kwargs1
=
"test"
)
quanter
=
quanter
.
_instance
(
layer
)
self
.
assertEqual
(
quanter
.
bit_length
(),
bit_length
)
self
.
assertEqual
(
quanter
.
quant_axis
(),
linear_quant_axis
)
self
.
assertEqual
(
quanter
.
_kwargs1
,
'test'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/tests/quantization/test_quant.py
0 → 100644
浏览文件 @
941444aa
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2D
,
Linear
,
ReLU
,
Sequential
from
paddle.quantization
import
QuantConfig
from
paddle.quantization.base_quanter
import
BaseQuanter
from
paddle.quantization.quanters
import
FakeQuanterWithAbsMaxObserver
class
LeNetDygraph
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
LeNetDygraph
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
features
=
Sequential
(
Conv2D
(
3
,
6
,
3
,
stride
=
1
,
padding
=
1
),
ReLU
(),
paddle
.
nn
.
MaxPool2D
(
2
,
2
),
Conv2D
(
6
,
16
,
5
,
stride
=
1
,
padding
=
0
),
ReLU
(),
paddle
.
nn
.
MaxPool2D
(
2
,
2
),
)
if
num_classes
>
0
:
self
.
fc
=
Sequential
(
Linear
(
576
,
120
),
Linear
(
120
,
84
),
Linear
(
84
,
10
)
)
def
forward
(
self
,
inputs
):
x
=
self
.
features
(
inputs
)
if
self
.
num_classes
>
0
:
x
=
paddle
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
out
=
F
.
relu
(
x
)
out
=
F
.
relu
(
out
)
return
out
class
TestQuantConfig
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
model
=
LeNetDygraph
()
self
.
quanter
=
FakeQuanterWithAbsMaxObserver
(
moving_rate
=
0.9
)
def
test_global_config
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
self
.
quanter
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assertIsNotNone
(
self
.
q_config
.
global_config
.
activation
)
self
.
assertIsNotNone
(
self
.
q_config
.
global_config
.
weight
)
for
layer
in
self
.
model
.
sublayers
():
config
=
self
.
q_config
.
_get_config_by_layer
(
layer
)
self
.
assertTrue
(
config
.
activation
==
self
.
quanter
)
self
.
assertTrue
(
config
.
weight
==
self
.
quanter
)
def
assert_just_linear_weight_configure
(
self
,
model
,
config
):
for
layer
in
model
.
sublayers
():
layer_config
=
config
.
_get_config_by_layer
(
layer
)
if
type
(
layer
)
==
Linear
:
self
.
assertIsNone
(
layer_config
.
activation
)
self
.
assertEqual
(
layer_config
.
weight
,
self
.
quanter
)
self
.
assertTrue
(
config
.
_is_quantifiable
(
layer
))
elif
type
(
layer
)
==
Conv2D
:
self
.
assertIsNone
(
layer_config
)
self
.
assertFalse
(
config
.
_is_quantifiable
(
layer
))
def
test_add_layer_config
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_layer_config
(
[
self
.
model
.
fc
],
activation
=
None
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assert_just_linear_weight_configure
(
self
.
model
,
self
.
q_config
)
def
test_add_name_config
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_name_config
(
[
self
.
model
.
fc
.
full_name
()],
activation
=
None
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assert_just_linear_weight_configure
(
self
.
model
,
self
.
q_config
)
def
test_add_type_config
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_type_config
(
[
Linear
],
activation
=
None
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assert_just_linear_weight_configure
(
self
.
model
,
self
.
q_config
)
def
test_add_qat_layer_mapping
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_qat_layer_mapping
(
Sequential
,
Conv2D
)
self
.
assertTrue
(
Sequential
in
self
.
q_config
.
qat_layer_mappings
)
self
.
assertTrue
(
Sequential
not
in
self
.
q_config
.
default_qat_layer_mapping
)
def
test_add_customized_leaf
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_customized_leaf
(
Sequential
)
self
.
assertTrue
(
Sequential
in
self
.
q_config
.
customized_leaves
)
self
.
assertTrue
(
self
.
q_config
.
_is_customized_leaf
(
self
.
model
.
fc
))
self
.
assertTrue
(
self
.
q_config
.
_is_leaf
(
self
.
model
.
fc
))
self
.
assertFalse
(
self
.
q_config
.
_is_default_leaf
(
self
.
model
.
fc
))
self
.
assertFalse
(
self
.
q_config
.
_is_real_leaf
(
self
.
model
.
fc
))
def
test_need_observe
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_layer_config
(
[
self
.
model
.
fc
],
activation
=
self
.
quanter
,
weight
=
self
.
quanter
)
self
.
q_config
.
add_customized_leaf
(
Sequential
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assertTrue
(
self
.
q_config
.
_has_observer_config
(
self
.
model
.
fc
))
self
.
assertTrue
(
self
.
q_config
.
_need_observe
(
self
.
model
.
fc
))
def
test__get_observer
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
None
,
weight
=
None
)
self
.
q_config
.
add_layer_config
(
[
self
.
model
.
fc
],
activation
=
self
.
quanter
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
observer
=
self
.
q_config
.
_get_observer
(
self
.
model
.
fc
)
self
.
assertIsInstance
(
observer
,
BaseQuanter
)
def
test_details
(
self
):
self
.
q_config
=
QuantConfig
(
activation
=
self
.
quanter
,
weight
=
self
.
quanter
)
self
.
q_config
.
_specify
(
self
.
model
)
self
.
assertIsNotNone
(
self
.
q_config
.
details
())
self
.
assertIsNotNone
(
self
.
q_config
.
__str__
())
if
__name__
==
'__main__'
:
unittest
.
main
()
python/setup.py.in
浏览文件 @
941444aa
...
...
@@ -386,6 +386,8 @@ packages=['paddle',
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
'paddle.incubate.distributed.models.moe.gate',
'paddle.quantization',
'paddle.quantization.quanters',
'paddle.sparse',
'paddle.sparse.nn',
'paddle.sparse.nn.layer',
...
...
setup.py
浏览文件 @
941444aa
...
...
@@ -1254,6 +1254,8 @@ def get_setup_parameters():
'paddle.incubate.distributed.models'
,
'paddle.incubate.distributed.models.moe'
,
'paddle.incubate.distributed.models.moe.gate'
,
'paddle.quantization'
,
'paddle.quantization.quanters'
,
'paddle.sparse'
,
'paddle.sparse.nn'
,
'paddle.sparse.nn.layer'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录