PaddlePaddle / PaddleSlim

Commit 7dce5a2e — Add Quanters (#1686)
Authored by Chang Xu on Mar 17, 2023; committed via GitHub on Mar 17, 2023 (unverified).
Parent: f54331a6
Showing 7 changed files with 948 additions and 0 deletions (+948, -0):
paddleslim/quant/quanters/__init__.py           +19   -0
paddleslim/quant/quanters/base_fake_quanter.py  +51   -0
paddleslim/quant/quanters/lsq_act.py            +197  -0
paddleslim/quant/quanters/lsq_func.py           +99   -0
paddleslim/quant/quanters/lsq_weight.py         +204  -0
paddleslim/quant/quanters/pact.py               +111  -0
tests/quantization/test_quanters.py             +267  -0
paddleslim/quant/quanters/__init__.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lsq_act import ActLSQplusQuanter
from .lsq_weight import WeightLSQplusQuanter
from .pact import PACTQuanter

__all__ = ["ActLSQplusQuanter", "WeightLSQplusQuanter", "PACTQuanter"]
paddleslim/quant/quanters/base_fake_quanter.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import paddle
import numpy as np
from paddle.quantization.base_quanter import BaseQuanter


class BaseFakeQuanterLayer(BaseQuanter):
    def __init__(self, quant_bits=8, sign=True, symmetric=True):
        super(BaseFakeQuanterLayer, self).__init__()
        self._quant_bits = quant_bits
        self._sign = sign
        self._symmetric = symmetric

        self._min = None
        self._max = None
        self._qmin = None
        self._qmax = None

        self._scale = None
        self._zero_point = None

    @property
    def qmin_qmax(self):
        """ Get the range of the integer. """
        # Cache the range in _qmin/_qmax so the guard below stays consistent
        # with the attributes it checks.
        if self._qmin is not None and self._qmax is not None:
            return self._qmin, self._qmax
        if self._sign:
            self._qmin = -2**(self.bit_length() - 1)
            self._qmax = 2**(self.bit_length() - 1) - 1
        else:
            self._qmin = 0
            self._qmax = 2**self.bit_length()
        return self._qmin, self._qmax
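
For intuition, here is a small sketch (an editorial illustration, not part of the commit) of the integer ranges `qmin_qmax` yields for the default 8-bit settings:

    # Illustrative only: ranges derived by BaseFakeQuanterLayer.qmin_qmax.
    bits = 8
    signed_range = (-2**(bits - 1), 2**(bits - 1) - 1)  # (-128, 127) when sign=True
    unsigned_range = (0, 2**bits)                       # (0, 256) when sign=False, as computed above
    print(signed_range, unsigned_range)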
paddleslim/quant/quanters/lsq_act.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from .base_fake_quanter import BaseFakeQuanterLayer
from .lsq_func import LsqFunc, LsqPlusActFunc, round


class ActLSQplusQuanter(QuanterFactory):
    r"""
    Activation quantizer. More details can be found in
    https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf.

    Args:
        per_channel(bool): Whether to use channel-wise (True) or layer-wise (False) quantization.
        batch_init(int): Number of batches over which a Gaussian approximation of the activation distribution is collected in each layer.
        quant_linear(bool): Whether the quantized tensor belongs to a Linear layer.
        dtype(str): Data type.
        name(str): The name of the layer.
        reduce_type(str): The reduce type, which is needed for parallel training.

    Examples:
       .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter

            weight_quanter = WeightLSQplusQuanter()
            act_quanter = ActLSQplusQuanter()
            q_config = QuantConfig(activation=act_quanter, weight=weight_quanter)
    """

    def __init__(self,
                 quant_bits=8,
                 sign=True,
                 symmetric=True,
                 per_channel=False,
                 batch_init=20,
                 quant_linear=False,
                 reduce_type=None,
                 dtype='float32',
                 name=None):
        super(ActLSQplusQuanter, self).__init__(
            quant_bits=quant_bits,
            sign=sign,
            symmetric=symmetric,
            per_channel=per_channel,
            batch_init=batch_init,
            quant_linear=quant_linear,
            reduce_type=reduce_type,
            dtype=dtype,
            name=name)

    def _get_class(self):
        return ActLSQplusQuanterLayer


class ActLSQplusQuanterLayer(BaseFakeQuanterLayer):
    def __init__(self,
                 layer,
                 quant_bits=8,
                 sign=True,
                 symmetric=True,
                 per_channel=False,
                 batch_init=20,
                 quant_linear=False,
                 reduce_type=None,
                 dtype='float32',
                 name=None):
        # Pass the quantization settings through to the base class so that
        # quant_bits/sign/symmetric are not silently reset to their defaults.
        super(ActLSQplusQuanterLayer, self).__init__(
            quant_bits=quant_bits, sign=sign, symmetric=symmetric)
        self._per_channel = per_channel
        self._quant_linear = quant_linear
        self._batch_init = batch_init
        self._name = name
        self._quant_axis = 1 if quant_linear else 0
        self._collect_axis = 0 if quant_linear else 1
        self._reduce_type = reduce_type
        self.div = 2**self._quant_bits - 1
        self.qmin, self.qmax = self.qmin_qmax
        self._current_batch_id = 0
        self._init_state = 0

        scale_prefix = ("{}.scale".format(name)
                        if name else 'quant_dequant.scale')
        self._scale_name = unique_name.generate(scale_prefix)
        s_attr = ParamAttr(
            name=self._scale_name, initializer=Constant(1.0), trainable=True)
        self._scale = self.create_parameter(
            shape=[1], attr=s_attr, dtype=dtype)
        self._scale.stop_gradient = False

        if not self._symmetric:
            beta_prefix = ("{}.beta".format(name)
                           if name else 'quant_dequant.beta')
            self._beta_name = unique_name.generate(beta_prefix)
            beta_attr = ParamAttr(
                name=self._beta_name,
                initializer=Constant(0.0),
                trainable=True)
            self._beta = self.create_parameter(
                shape=[1], attr=beta_attr, dtype='float32')
            self._beta.stop_gradient = False

    def init_params(self, activation):
        self.g = paddle.to_tensor(
            1.0 / math.sqrt(activation.numel() * self.qmax))
        min_a = paddle.min(activation.detach())
        max_a = paddle.max(activation.detach())
        self._scale.set_value((max_a - min_a) / (self.qmax - self.qmin))
        if not self._symmetric:
            self._beta.set_value(min_a - self._scale * self.qmin)
        self._init_state += 1

    def collect_gaussian(self, activation):
        min_a = paddle.min(activation.detach())
        max_a = paddle.max(activation.detach())
        self._scale.set_value(self._scale * 0.9 + 0.1 * (max_a - min_a) /
                              (self.qmax - self.qmin))
        if not self._symmetric:
            # Moving average of beta, mirroring the moving average of scale above.
            self._beta.set_value(self._beta * 0.9 + 0.1 *
                                 (min_a - self._scale * self.qmin))
        self._init_state += 1

    def forward(self, activation):
        if self._reduce_type == "max":
            paddle.distributed.all_reduce(
                self._scale, op=paddle.distributed.ReduceOp.MAX)
            if not self._symmetric:
                paddle.distributed.all_reduce(
                    self._beta, op=paddle.distributed.ReduceOp.MAX)

        if self._init_state == 0:
            self.init_params(activation)
        elif self._init_state < self._batch_init:
            self.collect_gaussian(activation)

        activation.stop_gradient = False
        if not self._symmetric:
            q_a = LsqPlusActFunc.apply(activation, self._scale, self._beta,
                                       self.g, self.qmin, self.qmax)
        else:
            q_a = LsqFunc.apply(
                activation,
                self._scale,
                self.g,
                self.qmin,
                self.qmax,
                per_channel=False)
        return q_a

    def bit_length(self):
        """ Return the bit length of quantized data. """
        return self._quant_bits

    def quant_axis(self):
        """ Return quantization axis. """
        return self._quant_axis

    def scales(self):
        """ Return output scales. """
        return self._scale

    def zero_points(self):
        """ Return output zero points. """
        if self._zero_point is None:
            if self._symmetric:
                if self._sign:
                    self._zero_point = 0
                else:
                    self._zero_point = (self.qmax + self.qmin) / 2
            else:
                self._zero_point = self.qmin - round(self.qmin / self._scale)
                self._zero_point = paddle.clip(self._zero_point, self.qmin,
                                               self.qmax)
        return self._zero_point
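
A minimal end-to-end sketch of wiring these quanters into QAT, assembled from the docstring above and the tests at the end of this commit (illustrative, not part of the commit itself):

    import paddle
    from paddle.quantization import QAT, QuantConfig
    from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter

    # Fake-quantize every Conv2D: LSQ+ for activations, per-channel LSQ+ for weights.
    q_config = QuantConfig(activation=None, weight=None)
    q_config.add_type_config(
        paddle.nn.Conv2D,
        activation=ActLSQplusQuanter(),
        weight=WeightLSQplusQuanter(per_channel=True))

    model = paddle.vision.models.resnet18()
    model.train()
    quant_model = QAT(q_config).quantize(model, inplace=False)
    out = quant_model(paddle.rand([1, 3, 224, 224]))  # early batches also initialize the scales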
paddleslim/quant/quanters/lsq_func.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import paddle
from paddle.autograd import PyLayer


def round(x):
    """ Round half away from zero. """
    sign = paddle.sign(x)
    x = sign * paddle.floor(paddle.abs(x) + 0.5)
    return x


class LsqFunc(PyLayer):
    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False,
                quant_axis=0):
        ctx.save_for_backward(weight, alpha)
        ctx.other = g, Qn, Qp, per_channel, quant_axis
        if per_channel:
            sizes = weight.shape
            weight = weight.reshape((weight.shape[quant_axis], -1))
            weight = weight.transpose((1, 0))
            alpha = paddle.broadcast_to(alpha, weight.shape)
            quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
            quant_w = quant_w * alpha
            quant_w = quant_w.transpose((1, 0))
            quant_w = quant_w.reshape(sizes)
        else:
            quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
            quant_w = quant_w * alpha
        return quant_w

    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha = ctx.saved_tensor()
        g, Qn, Qp, per_channel, quant_axis = ctx.other
        if per_channel:
            sizes = weight.shape
            weight = weight.reshape((weight.shape[quant_axis], -1))
            weight = weight.transpose((1, 0))
            alpha = paddle.broadcast_to(alpha, weight.shape)
            q_w = paddle.divide(weight, alpha)
            q_w = q_w.transpose((1, 0))
            q_w = q_w.reshape(sizes)
        else:
            q_w = paddle.divide(weight, alpha)
        # Straight-through estimator: pass gradients only inside [Qn, Qp].
        lower_flag = paddle.cast((q_w < Qn), 'float32')
        upper_flag = paddle.cast((q_w > Qp), 'float32')
        middle_flag = 1.0 - lower_flag - upper_flag
        if per_channel:
            grad_alpha = ((lower_flag * Qn + upper_flag * Qp + middle_flag *
                           round(q_w) - middle_flag * q_w) * grad_weight * g)
            grad_alpha = grad_alpha.reshape(
                (grad_alpha.shape[quant_axis], -1)).sum(axis=1)
        else:
            grad_alpha = (((lower_flag * Qn + upper_flag * Qp + middle_flag *
                            round(q_w) - middle_flag * q_w) * grad_weight *
                           g).sum().unsqueeze(axis=0)[0])
        grad_weight = middle_flag * grad_weight
        return grad_weight, grad_alpha


class LsqPlusActFunc(PyLayer):
    @staticmethod
    def forward(ctx, x, alpha, beta, g, Qn, Qp):
        ctx.save_for_backward(x, alpha, beta)
        ctx.other = g, Qn, Qp
        quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp)
        return quant_x * alpha + beta

    @staticmethod
    def backward(ctx, grad_x):
        x, alpha, beta = ctx.saved_tensor()
        g, Qn, Qp = ctx.other
        q_x = (x - beta) / alpha
        lower_flag = paddle.cast((q_x < Qn), 'float32')
        upper_flag = paddle.cast((q_x > Qp), 'float32')
        middle_flag = 1.0 - lower_flag - upper_flag
        grad_alpha = (((lower_flag * Qn + upper_flag * Qp + middle_flag *
                        round(q_x) - middle_flag * q_x) * grad_x *
                       g).sum().unsqueeze(axis=0)[0])
        grad_beta = (((lower_flag + upper_flag) * grad_x *
                      g).sum().unsqueeze(axis=0)[0])
        grad_x = middle_flag * grad_x
        return grad_x, grad_alpha, grad_beta
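
To make the forward semantics concrete, a small sketch (illustrative, not part of the commit; it assumes the module layout added here, `paddleslim.quant.quanters.lsq_func`):

    import math
    import paddle
    from paddleslim.quant.quanters.lsq_func import LsqFunc, round as round_half_away

    x = paddle.to_tensor([-2.5, -0.5, 0.5, 2.5])
    print(round_half_away(x).numpy())  # [-3. -1.  1.  3.]  (halves round away from zero)

    # The forward pass is a fake quant-dequant: round(w / alpha).clip(Qn, Qp) * alpha.
    w = paddle.rand([4, 4]) - 0.5
    w.stop_gradient = False
    alpha = paddle.to_tensor([0.05])
    g = paddle.to_tensor(1.0 / math.sqrt(w.numel() * 127))  # gradient scale, as in the quanter layers
    w_q = LsqFunc.apply(w, alpha, g, -128, 127)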
paddleslim/quant/quanters/lsq_weight.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from .lsq_func import LsqFunc, round
from .base_fake_quanter import BaseFakeQuanterLayer


class WeightLSQplusQuanter(QuanterFactory):
    r"""
    Weight quantizer. More details can be found in
    https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf.

    Args:
        per_channel(bool): Whether to use channel-wise (True) or layer-wise (False) quantization.
        batch_init(int): Number of batches over which a Gaussian approximation of the weight distribution is collected in each layer.
        quant_linear(bool): Whether the quantized weight belongs to a Linear layer.
        channel_num(int, optional): Number of channels for per-channel quantization.
        dtype(str): Trainable data type.
        name(str): The name of the layer.
        reduce_type(str): The reduce type, which is needed for parallel training.

    Examples:
       .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter

            weight_quanter = WeightLSQplusQuanter()
            act_quanter = ActLSQplusQuanter()
            q_config = QuantConfig(activation=act_quanter, weight=weight_quanter)
    """

    def __init__(self,
                 quant_bits=8,
                 sign=True,
                 symmetric=True,
                 per_channel=False,
                 batch_init=20,
                 quant_linear=False,
                 channel_num=None,
                 reduce_type=None,
                 dtype='float32',
                 name=None):
        super(WeightLSQplusQuanter, self).__init__(
            quant_bits=quant_bits,
            sign=sign,
            symmetric=symmetric,
            per_channel=per_channel,
            batch_init=batch_init,
            quant_linear=quant_linear,
            channel_num=channel_num,
            reduce_type=reduce_type,
            dtype=dtype,
            name=name)

    def _get_class(self):
        return WeightLSQplusQuanterLayer


class WeightLSQplusQuanterLayer(BaseFakeQuanterLayer):
    def __init__(self,
                 layer,
                 quant_bits=8,
                 sign=True,
                 symmetric=True,
                 per_channel=False,
                 all_positive=False,
                 batch_init=20,
                 quant_linear=False,
                 channel_num=None,
                 reduce_type=None,
                 dtype='float32',
                 name=None):
        # Pass the quantization settings through to the base class so that
        # quant_bits/sign/symmetric are not silently reset to their defaults.
        super(WeightLSQplusQuanterLayer, self).__init__(
            quant_bits=quant_bits, sign=sign, symmetric=symmetric)
        self._per_channel = per_channel
        self._quant_linear = quant_linear
        self._batch_init = batch_init
        self._name = name
        self._quant_axis = 1 if quant_linear else 0
        self._collect_axis = 0 if quant_linear else 1
        self._reduce_type = reduce_type
        self.div = 2**self._quant_bits - 1
        self.qmin, self.qmax = self.qmin_qmax
        self._current_batch_id = 0
        self._init_state = 0

        scale_prefix = ("{}.scale".format(name)
                        if name else 'quant_dequant.scale')
        self._scale_name = unique_name.generate(scale_prefix)
        s_attr = ParamAttr(
            name=self._scale_name, initializer=Constant(1.0), trainable=True)
        channel_num = layer.weight.shape[
            self._quant_axis] if self._per_channel else 1
        self._scale = self.create_parameter(
            shape=[channel_num], attr=s_attr, dtype=dtype)
        self._scale.stop_gradient = False

    def init_params(self, weight):
        self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() * self.qmax))
        if self._per_channel:
            weight_tmp = weight.detach().reshape((weight.shape[0], -1))
            mean = paddle.mean(weight_tmp, axis=self._collect_axis)
            std = paddle.std(weight_tmp, axis=self._collect_axis)
            s = paddle.max(
                paddle.stack(
                    [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]),
                axis=0)
            self._scale.set_value(s / self.div)
        else:
            mean = paddle.mean(weight.detach())
            std = paddle.std(weight.detach())
            self._scale.set_value(
                max([paddle.abs(mean - 3 * std),
                     paddle.abs(mean + 3 * std)]) / self.div)
        self._init_state += 1

    def collect_gaussian(self, weight):
        if self._per_channel:
            weight_tmp = weight.detach().reshape((weight.shape[0], -1))
            mean = paddle.mean(weight_tmp, axis=self._collect_axis)
            std = paddle.std(weight_tmp, axis=self._collect_axis)
            s = paddle.max(
                paddle.stack(
                    [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]),
                axis=0)
            # Moving average of the scale, mirroring the layer-wise branch below.
            self._scale.set_value(self._scale * 0.9 + 0.1 * s / self.div)
        else:
            mean = paddle.mean(weight.detach())
            std = paddle.std(weight.detach())
            self._scale.set_value(self._scale * 0.9 + 0.1 * max(
                [paddle.abs(mean - 3 * std),
                 paddle.abs(mean + 3 * std)]) / self.div)
        self._init_state += 1

    def forward(self, weight):
        if self._reduce_type == "max":
            paddle.distributed.all_reduce(
                self._scale, op=paddle.distributed.ReduceOp.MAX)
        if self._init_state == 0:
            self.init_params(weight)
        elif self._init_state < self._batch_init:
            self.collect_gaussian(weight)
        weight.stop_gradient = False
        w_q = LsqFunc.apply(
            weight,
            self._scale,
            self.g,
            self.qmin,
            self.qmax,
            self._per_channel,
            self._quant_axis, )
        return w_q

    def bit_length(self):
        """ Return the bit length of quantized data. """
        return self._quant_bits

    def quant_axis(self):
        """ Return quantization axis. """
        return self._quant_axis

    def scales(self):
        """ Return output scales. """
        return self._scale

    def zero_points(self):
        """ Return output zero points. """
        if self._zero_point is None:
            if self._symmetric:
                if self._sign:
                    self._zero_point = 0
                else:
                    self._zero_point = (self.qmax + self.qmin) / 2
            else:
                self._zero_point = self.qmin - round(self.qmin / self._scale)
                self._zero_point = paddle.clip(self._zero_point, self.qmin,
                                               self.qmax)
        return self._zero_point
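
In the layer-wise case, the 3-sigma scale initialization above reduces to the following (illustrative sketch, not part of the commit):

    import paddle

    weight = paddle.randn([16, 8])
    mean, std = paddle.mean(weight), paddle.std(weight)
    div = 2**8 - 1  # self.div for the default quant_bits=8
    # Initial LSQ+ scale: the wider of the two 3-sigma bounds over the integer range.
    scale = max([paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)]) / div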
paddleslim/quant/quanters/pact.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from paddle.quantization.base_quanter import BaseQuanter


class PACTQuanter(QuanterFactory):
    r"""
    PArameterized Clipping acTivation (PACT) uses a learnable activation clipping
    parameter alpha to find the right quantization scale. More details can be found in
    https://arxiv.org/pdf/1805.06085.pdf.

    Args:
        quanter(BaseQuanter, required): Any BaseQuanter; PACT can be combined with any other quantization method.
        init_value(float, optional): Initial value of alpha. Default: 100.
        learning_rate(float, optional): The learning rate of alpha during optimization.
        dtype(str): Trainable data type.
        name(str): The name of the layer.

    Examples:
       .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.quanters import PACTQuanter
            from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer

            pact_quanter = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer)
            q_config = QuantConfig(activation=pact_quanter, weight=pact_quanter)
    """

    def __init__(self,
                 quanter,
                 init_value=100.,
                 learning_rate=1000.,
                 dtype='float32',
                 name=None):
        super(PACTQuanter, self).__init__(
            quanter=quanter,
            init_value=init_value,
            learning_rate=learning_rate,
            dtype=dtype,
            name=name)

    def _get_class(self):
        return PACTQuanterLayer


class PACTQuanterLayer(BaseQuanter):
    def __init__(self,
                 layer,
                 quanter,
                 init_value=1000,
                 learning_rate=1000.,
                 dtype='float32',
                 name=None):
        super(PACTQuanterLayer, self).__init__()
        self.quanter = quanter(layer)
        alpha_prefix = ("{}.pact".format(name)
                        if name else 'quant_dequant.pact')
        name = unique_name.generate(alpha_prefix)
        alpha_attr = paddle.ParamAttr(
            name=name,
            initializer=paddle.nn.initializer.Constant(value=init_value),
            learning_rate=learning_rate)
        self.alpha = self.create_parameter(
            shape=[1], attr=alpha_attr, dtype=dtype)

    def forward(self, activation):
        # Clip the activation to [-alpha, alpha] before handing it to the
        # wrapped quanter.
        out_left = paddle.nn.functional.relu(activation - self.alpha)
        out_right = paddle.nn.functional.relu(-self.alpha - activation)
        activation = activation - out_left + out_right
        return self.quanter(activation)

    def bit_length(self):
        """ Return the bit length of quantized data. """
        return self.quanter.bit_length()

    def quant_axis(self):
        """ Return quantization axis. """
        return self.quanter.quant_axis()

    def scales(self):
        """ Return output scales. """
        return self.quanter.scales()

    def zero_points(self):
        """ Return output zero points. """
        return self.quanter.zero_points()
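
The relu-based expression in `PACTQuanterLayer.forward` is algebraically the same as clipping to `[-alpha, alpha]`; a quick check (illustrative, not part of the commit):

    import paddle
    import paddle.nn.functional as F

    x = paddle.linspace(-5.0, 5.0, 11)
    alpha = paddle.to_tensor([2.0])
    clipped = x - F.relu(x - alpha) + F.relu(-alpha - x)
    assert bool(paddle.allclose(clipped, paddle.clip(x, -2.0, 2.0)))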
tests/quantization/test_quanters.py (new file, mode 100644)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import unittest
import paddle
import tempfile
import numpy as np
sys.path.append("../../")
from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig
from paddle.quantization import QAT
from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter, PACTQuanter
from paddleslim.quant.quanters.lsq_act import ActLSQplusQuanterLayer
from paddleslim.quant.quanters.lsq_weight import WeightLSQplusQuanterLayer
from paddleslim.quant.quanters.pact import PACTQuanterLayer
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter

import logging
from paddleslim.common import get_logger
_logger = get_logger(__name__, level=logging.INFO)


class ImperativeLenet(paddle.nn.Layer):
    def __init__(self, num_classes=10, classifier_activation='softmax'):
        super(ImperativeLenet, self).__init__()
        self.features = paddle.nn.Sequential(
            paddle.nn.Conv2D(
                in_channels=1,
                out_channels=6,
                kernel_size=3,
                stride=1,
                padding=1),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2),
            paddle.nn.Conv2D(
                in_channels=6,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=0),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2))
        self.fc = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=400, out_features=120),
            paddle.nn.Linear(in_features=120, out_features=84),
            paddle.nn.Linear(in_features=84, out_features=num_classes), )

    def forward(self, inputs):
        x = self.features(inputs)
        x = paddle.flatten(x, 1)
        x = self.fc(x)
        return x


class TestQATWithQuanters(unittest.TestCase):
    def __init__(self, act_observer, act_observer_type, weight_observer,
                 weight_observer_type, *args, **kwargs):
        super(TestQATWithQuanters, self).__init__(*args, **kwargs)
        self.act_observer = act_observer
        self.act_observer_type = act_observer_type
        self.weight_observer = weight_observer
        self.weight_observer_type = weight_observer_type

    def setUp(self):
        self.init_case()
        self.dummy_input = paddle.rand([1, 3, 224, 224])
        self.temp_dir = tempfile.TemporaryDirectory(dir="./")
        self.path = os.path.join(self.temp_dir.name, 'qat')
        if not os.path.exists('ILSVRC2012_data_demo'):
            os.system(
                'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
            )
            os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
        seed = 1
        np.random.seed(seed)
        paddle.static.default_main_program().random_seed = seed
        paddle.static.default_startup_program().random_seed = seed

    def tearDown(self):
        self.temp_dir.cleanup()

    def runTest(self):
        self.test_quantize()
        self.test_convert()
        self.test_convergence()

    def init_case(self):
        self.q_config = QuantConfig(activation=None, weight=None)
        self.q_config.add_type_config(
            paddle.nn.Conv2D,
            activation=self.act_observer,
            weight=self.weight_observer)

    def _count_layers(self, model, layer_type):
        count = 0
        for _layer in model.sublayers(True):
            if isinstance(_layer, layer_type):
                count += 1
        return count

    def test_quantize(self):
        model = resnet18()
        conv_count = self._count_layers(model, paddle.nn.Conv2D)
        qat = QAT(self.q_config)
        model.train()
        quant_model = qat.quantize(model, inplace=False)
        out = quant_model(self.dummy_input)
        quantizer_cnt = self._count_layers(quant_model,
                                           self.act_observer_type)
        self.assertEqual(quantizer_cnt, conv_count)
        quantizer_cnt = self._count_layers(quant_model,
                                           self.weight_observer_type)
        self.assertEqual(quantizer_cnt, conv_count)

    def test_convergence(self):
        model = ImperativeLenet()
        conv_count = self._count_layers(model, paddle.nn.Conv2D)
        qat = QAT(self.q_config)
        model.train()
        quant_model = qat.quantize(model, inplace=False)
        place = paddle.CUDAPlace(0) \
            if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
        transform = paddle.vision.transforms.Compose([
            paddle.vision.transforms.Transpose(),
            paddle.vision.transforms.Normalize([127.5], [127.5])
        ])
        train_dataset = paddle.vision.datasets.MNIST(
            mode='train', backend='cv2', transform=transform)
        val_dataset = paddle.vision.datasets.MNIST(
            mode='test', backend='cv2', transform=transform)
        train_reader = paddle.io.DataLoader(
            train_dataset,
            drop_last=True,
            places=place,
            batch_size=64,
            return_list=True)
        test_reader = paddle.io.DataLoader(
            val_dataset, places=place, batch_size=64, return_list=True)

        def train(model):
            adam = paddle.optimizer.Adam(
                learning_rate=0.0001, parameters=model.parameters())
            epoch_num = 1
            for epoch in range(epoch_num):
                model.train()
                for batch_id, data in enumerate(train_reader):
                    img = paddle.to_tensor(data[0])
                    label = paddle.to_tensor(data[1])
                    img = paddle.reshape(img, [-1, 1, 28, 28])
                    label = paddle.reshape(label, [-1, 1])

                    out = model(img)
                    acc = paddle.metric.accuracy(out, label)
                    loss = paddle.nn.functional.loss.cross_entropy(out, label)
                    avg_loss = paddle.mean(loss)
                    avg_loss.backward()
                    adam.minimize(avg_loss)
                    model.clear_gradients()
                    if batch_id % 100 == 0:
                        _logger.info(
                            "Train | At epoch {} step {}: loss = {:}, acc= {:}".
                            format(epoch, batch_id, avg_loss.numpy(),
                                   acc.numpy()))

        def test(model):
            model.eval()
            avg_acc = [[], []]
            for batch_id, data in enumerate(test_reader):
                img = paddle.to_tensor(data[0])
                img = paddle.reshape(img, [-1, 1, 28, 28])
                label = paddle.to_tensor(data[1])
                label = paddle.reshape(label, [-1, 1])

                out = model(img)
                acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
                acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
                avg_acc[0].append(acc_top1.numpy())
                avg_acc[1].append(acc_top5.numpy())
                if batch_id % 100 == 0:
                    _logger.info("Test | step {}: acc1 = {:}, acc5 = {:}".
                                 format(batch_id, acc_top1.numpy(),
                                        acc_top5.numpy()))
            _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
                np.mean(avg_acc[0]), np.mean(avg_acc[1])))
            return np.mean(avg_acc[0]), np.mean(avg_acc[1])

        train(model)
        top1_1, top5_1 = test(model)

        quant_model.train()
        train(quant_model)
        top1_2, top5_2 = test(quant_model)

        _logger.info("Before quantization: top1: {}, top5: {}".format(top1_1,
                                                                      top5_1))
        _logger.info("After quantization: top1: {}, top5: {}".format(top1_2,
                                                                     top5_2))
        _logger.info("\n")

        diff = 0.01
        self.assertTrue(
            top1_1 - top1_2 < diff,
            msg="The accuracy of the quantized model is much lower than that of the FP32 model"
        )
        _logger.info('done')
        return

    def test_convert(self):
        model = resnet18()
        conv_count = self._count_layers(model, paddle.nn.Conv2D)
        qat = QAT(self.q_config)
        model.train()
        quant_model = qat.quantize(model, inplace=False)
        out = quant_model(self.dummy_input)
        converted_model = qat.convert(quant_model, inplace=False)

        # check count of LinearQuanter and LinearDequanter in dygraph
        quantizer_count_in_dygraph = self._count_layers(converted_model,
                                                        LinearQuanter)
        dequantizer_count_in_dygraph = self._count_layers(
            converted_model, LinearDequanter)
        self.assertEqual(quantizer_count_in_dygraph, conv_count)
        self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)


observer_suite = unittest.TestSuite()
observer_suite.addTest(
    TestQATWithQuanters(
        act_observer=ActLSQplusQuanter(),
        act_observer_type=ActLSQplusQuanterLayer,
        weight_observer=WeightLSQplusQuanter(),
        weight_observer_type=WeightLSQplusQuanterLayer))
observer_suite.addTest(
    TestQATWithQuanters(
        act_observer=ActLSQplusQuanter(symmetric=False),
        act_observer_type=ActLSQplusQuanterLayer,
        weight_observer=WeightLSQplusQuanter(per_channel=True),
        weight_observer_type=WeightLSQplusQuanterLayer))
observer_suite.addTest(
    TestQATWithQuanters(
        act_observer=PACTQuanter(quanter=ActLSQplusQuanterLayer),
        act_observer_type=PACTQuanterLayer,
        weight_observer=WeightLSQplusQuanter(),
        weight_observer_type=WeightLSQplusQuanterLayer))

if __name__ == '__main__':
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(observer_suite)
    os.system('rm -rf ILSVRC2012_data_demo.tar.gz')
    os.system('rm -rf ILSVRC2012_data_demo')