Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
341f68fe
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
341f68fe
编写于
9月 27, 2022
作者:
C
Chang Xu
提交者:
GitHub
9月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add LSQ/LSQ+ in QAT (#45652)
上级
383dc908
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
557 addition
and
5 deletion
+557
-5
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
.../paddle/fluid/contrib/slim/quantization/imperative/qat.py
+5
-3
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+3
-0
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py
...addle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py
+198
-0
python/paddle/nn/quant/lsq.py
python/paddle/nn/quant/lsq.py
+328
-0
python/paddle/nn/quant/quant_layers.py
python/paddle/nn/quant/quant_layers.py
+23
-2
未找到文件。
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
浏览文件 @
341f68fe
...
...
@@ -324,16 +324,18 @@ class ImperativeQuantizeInputs(object):
"%s is unspported to be quantized."
%
layer
quantize_type
=
{
'abs_max'
,
'moving_average_abs_max'
,
'channel_wise_abs_max'
'abs_max'
,
'moving_average_abs_max'
,
'channel_wise_abs_max'
,
'lsq_weight'
,
'channel_wise_lsq_weight'
}
act_quantize_type
=
{
'moving_average_abs_max'
,
'lsq_act'
}
assert
weight_quantize_type
!=
'moving_average_abs_max'
\
and
weight_quantize_type
in
quantize_type
,
\
"Unsupported weight_quantize_type: %s. It can only "
\
"be abs_max or channel_wise_abs_max."
%
weight_quantize_type
# TODO (jc): activation_quantize_type supports range_abs_max
assert
activation_quantize_type
==
'moving_average_abs_max'
,
\
assert
activation_quantize_type
in
act_quantize_type
,
\
"Unsupported activation_quantize_type: %s. It can "
\
"only be moving_average_abs_max now."
\
"only be moving_average_abs_max
or lsq_act
now."
\
%
activation_quantize_type
bits_check
=
lambda
bits
:
isinstance
(
bits
,
int
)
\
...
...
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
浏览文件 @
341f68fe
...
...
@@ -252,6 +252,7 @@ if(WIN32)
list
(
REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1
)
list
(
REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_qat_amp
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_qat_lsq
)
endif
()
if
(
LINUX AND WITH_MKLDNN
)
...
...
@@ -505,6 +506,7 @@ if(WIN32)
test_moving_average_abs_max_scale_op
test_imperative_qat_channelwise
test_imperative_qat
test_imperative_qat_lsq
test_imperative_out_scale
test_graph
)
list
(
REMOVE_ITEM TEST_OPS
${
SINGLE_CARD_TEST_OPS
}
)
...
...
@@ -544,6 +546,7 @@ set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
set_tests_properties
(
test_imperative_qat_fuse PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_imperative_out_scale PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_imperative_qat_user_defined PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_imperative_qat_lsq PROPERTIES TIMEOUT 300
)
if
(
LINUX AND WITH_MKLDNN
)
set_tests_properties
(
test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py
0 → 100644
浏览文件 @
341f68fe
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from
__future__
import
print_function
import
os
import
numpy
as
np
import
random
import
time
import
tempfile
import
unittest
import
logging
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
,
AdamOptimizer
,
MomentumOptimizer
from
paddle.fluid.contrib.slim.quantization
import
ImperativeQuantAware
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.nn
import
ReLU
,
ReLU6
,
LeakyReLU
,
Sigmoid
,
Softmax
,
PReLU
from
paddle.nn
import
Linear
,
Conv2D
,
Softmax
,
BatchNorm2D
,
MaxPool2D
from
paddle.fluid.log_helper
import
get_logger
from
paddle.fluid.dygraph.io
import
INFER_MODEL_SUFFIX
,
INFER_PARAMS_SUFFIX
from
paddle.nn.quant.quant_layers
import
QuantizedConv2D
,
QuantizedConv2DTranspose
from
paddle.fluid.framework
import
_test_eager_guard
from
imperative_test_utils
import
fix_model_dict
paddle
.
enable_static
()
os
.
environ
[
"CPU_NUM"
]
=
"1"
if
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
_logger
=
get_logger
(
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
ImperativeLenet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
ImperativeLenet
,
self
).
__init__
()
conv2d_w1_attr
=
fluid
.
ParamAttr
(
name
=
"conv2d_w_1"
)
conv2d_w2_attr
=
fluid
.
ParamAttr
(
name
=
"conv2d_w_2"
)
fc_w1_attr
=
fluid
.
ParamAttr
(
name
=
"fc_w_1"
)
fc_w2_attr
=
fluid
.
ParamAttr
(
name
=
"fc_w_2"
)
fc_w3_attr
=
fluid
.
ParamAttr
(
name
=
"fc_w_3"
)
conv2d_b2_attr
=
fluid
.
ParamAttr
(
name
=
"conv2d_b_2"
)
fc_b1_attr
=
fluid
.
ParamAttr
(
name
=
"fc_b_1"
)
fc_b2_attr
=
fluid
.
ParamAttr
(
name
=
"fc_b_2"
)
fc_b3_attr
=
fluid
.
ParamAttr
(
name
=
"fc_b_3"
)
self
.
features
=
Sequential
(
Conv2D
(
in_channels
=
1
,
out_channels
=
6
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
conv2d_w1_attr
,
bias_attr
=
False
),
BatchNorm2D
(
6
),
ReLU
(),
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
Conv2D
(
in_channels
=
6
,
out_channels
=
16
,
kernel_size
=
5
,
stride
=
1
,
padding
=
0
,
weight_attr
=
conv2d_w2_attr
,
bias_attr
=
conv2d_b2_attr
),
BatchNorm2D
(
16
),
PReLU
(),
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
))
self
.
fc
=
Sequential
(
Linear
(
in_features
=
400
,
out_features
=
120
,
weight_attr
=
fc_w1_attr
,
bias_attr
=
fc_b1_attr
),
LeakyReLU
(),
Linear
(
in_features
=
120
,
out_features
=
84
,
weight_attr
=
fc_w2_attr
,
bias_attr
=
fc_b2_attr
),
Sigmoid
(),
Linear
(
in_features
=
84
,
out_features
=
num_classes
,
weight_attr
=
fc_w3_attr
,
bias_attr
=
fc_b3_attr
),
Softmax
())
def
forward
(
self
,
inputs
):
x
=
self
.
features
(
inputs
)
x
=
fluid
.
layers
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
return
x
class
TestImperativeQatLSQ
(
unittest
.
TestCase
):
def
set_vars
(
self
):
self
.
weight_quantize_type
=
'channel_wise_lsq_weight'
self
.
activation_quantize_type
=
'lsq_act'
self
.
onnx_format
=
False
self
.
fuse_conv_bn
=
False
def
func_qat
(
self
):
self
.
set_vars
()
imperative_qat
=
ImperativeQuantAware
(
weight_quantize_type
=
self
.
weight_quantize_type
,
activation_quantize_type
=
self
.
activation_quantize_type
,
fuse_conv_bn
=
self
.
fuse_conv_bn
)
seed
=
100
np
.
random
.
seed
(
seed
)
fluid
.
default_main_program
().
random_seed
=
seed
fluid
.
default_startup_program
().
random_seed
=
seed
paddle
.
disable_static
()
lenet
=
ImperativeLenet
()
lenet
=
fix_model_dict
(
lenet
)
imperative_qat
.
quantize
(
lenet
)
optimizer
=
MomentumOptimizer
(
learning_rate
=
0.1
,
parameter_list
=
lenet
.
parameters
(),
momentum
=
0.9
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
64
,
drop_last
=
True
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
32
)
epoch_num
=
2
for
epoch
in
range
(
epoch_num
):
lenet
.
train
()
for
batch_id
,
data
in
enumerate
(
train_reader
()):
x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
fluid
.
dygraph
.
to_variable
(
x_data
)
label
=
fluid
.
dygraph
.
to_variable
(
y_data
)
out
=
lenet
(
img
)
acc
=
fluid
.
layers
.
accuracy
(
out
,
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
paddle
.
mean
(
loss
)
avg_loss
.
backward
()
optimizer
.
minimize
(
avg_loss
)
lenet
.
clear_gradients
()
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Train | At epoch {} step {}: loss = {:}, acc= {:}"
.
format
(
epoch
,
batch_id
,
avg_loss
.
numpy
(),
acc
.
numpy
()))
lenet
.
eval
()
eval_acc_top1_list
=
[]
with
paddle
.
no_grad
():
for
batch_id
,
data
in
enumerate
(
test_reader
()):
x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
fluid
.
dygraph
.
to_variable
(
x_data
)
label
=
fluid
.
dygraph
.
to_variable
(
y_data
)
out
=
lenet
(
img
)
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
if
batch_id
%
100
==
0
:
eval_acc_top1_list
.
append
(
float
(
acc_top1
.
numpy
()))
_logger
.
info
(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}"
.
format
(
epoch
,
batch_id
,
acc_top1
.
numpy
(),
acc_top5
.
numpy
()))
# check eval acc
eval_acc_top1
=
sum
(
eval_acc_top1_list
)
/
len
(
eval_acc_top1_list
)
print
(
'eval_acc_top1'
,
eval_acc_top1
)
self
.
assertTrue
(
eval_acc_top1
>
0.9
,
msg
=
"The test acc {%f} is less than 0.9."
%
eval_acc_top1
)
def
test_qat
(
self
):
self
.
func_qat
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/nn/quant/lsq.py
0 → 100644
浏览文件 @
341f68fe
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
from
paddle.framework
import
core
from
paddle.fluid
import
dygraph_utils
from
paddle.utils
import
unique_name
from
paddle.framework
import
ParamAttr
from
paddle.fluid.framework
import
_varbase_creator
from
paddle.nn.initializer
import
Constant
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
from
paddle.nn
import
functional
as
F
import
logging
from
paddle.fluid.log_helper
import
get_logger
from
paddle
import
in_dynamic_mode
from
paddle.nn
import
Layer
from
paddle.autograd
import
PyLayer
import
math
import
copy
def
round
(
x
):
sign
=
paddle
.
sign
(
x
)
x
=
sign
*
paddle
.
floor
(
paddle
.
abs
(
x
)
+
0.5
)
return
x
class
LsqFunc
(
PyLayer
):
@
staticmethod
def
forward
(
ctx
,
weight
,
alpha
,
g
,
Qn
,
Qp
,
per_channel
=
False
,
quant_axis
=
0
):
ctx
.
save_for_backward
(
weight
,
alpha
)
ctx
.
other
=
g
,
Qn
,
Qp
,
per_channel
,
quant_axis
if
per_channel
:
sizes
=
weight
.
shape
weight
=
weight
.
reshape
((
weight
.
shape
[
quant_axis
],
-
1
))
weight
=
weight
.
transpose
((
1
,
0
))
alpha
=
paddle
.
broadcast_to
(
alpha
,
weight
.
shape
)
quant_w
=
round
(
paddle
.
divide
(
weight
,
alpha
)).
clip
(
Qn
,
Qp
)
quant_w
=
quant_w
*
alpha
quant_w
=
quant_w
.
transpose
((
1
,
0
))
quant_w
=
quant_w
.
reshape
(
sizes
)
else
:
quant_w
=
round
(
paddle
.
divide
(
weight
,
alpha
)).
clip
(
Qn
,
Qp
)
quant_w
=
quant_w
*
alpha
return
quant_w
@
staticmethod
def
backward
(
ctx
,
grad_weight
):
weight
,
alpha
=
ctx
.
saved_tensor
()
g
,
Qn
,
Qp
,
per_channel
,
quant_axis
=
ctx
.
other
if
per_channel
:
sizes
=
weight
.
shape
weight
=
weight
.
reshape
((
weight
.
shape
[
quant_axis
],
-
1
))
weight
=
weight
.
transpose
((
1
,
0
))
alpha
=
paddle
.
broadcast_to
(
alpha
,
weight
.
shape
)
q_w
=
paddle
.
divide
(
weight
,
alpha
)
q_w
=
q_w
.
transpose
((
1
,
0
))
q_w
=
q_w
.
reshape
(
sizes
)
else
:
q_w
=
paddle
.
divide
(
weight
,
alpha
)
lower_flag
=
paddle
.
cast
((
q_w
<
Qn
),
'float32'
)
upper_flag
=
paddle
.
cast
((
q_w
>
Qp
),
'float32'
)
middle_flag
=
1.0
-
lower_flag
-
upper_flag
if
per_channel
:
grad_alpha
=
((
lower_flag
*
Qn
+
upper_flag
*
Qp
+
middle_flag
*
round
(
q_w
)
-
middle_flag
*
q_w
)
*
grad_weight
*
g
)
grad_alpha
=
grad_alpha
.
reshape
(
(
grad_alpha
.
shape
[
quant_axis
],
-
1
)).
sum
(
axis
=
1
)
else
:
grad_alpha
=
((
lower_flag
*
Qn
+
upper_flag
*
Qp
+
middle_flag
*
round
(
q_w
)
-
middle_flag
*
q_w
)
*
grad_weight
*
g
).
sum
().
unsqueeze
(
axis
=
0
)[
0
]
grad_weight
=
middle_flag
*
grad_weight
return
grad_weight
,
grad_alpha
class
LsqPlusActFunc
(
PyLayer
):
@
staticmethod
def
forward
(
ctx
,
x
,
alpha
,
beta
,
g
,
Qn
,
Qp
):
ctx
.
save_for_backward
(
x
,
alpha
,
beta
)
ctx
.
other
=
g
,
Qn
,
Qp
quant_x
=
round
(
paddle
.
divide
((
x
-
beta
),
alpha
)).
clip
(
Qn
,
Qp
)
return
quant_x
*
alpha
+
beta
@
staticmethod
def
backward
(
ctx
,
grad_x
):
x
,
alpha
,
beta
=
ctx
.
saved_tensor
()
g
,
Qn
,
Qp
=
ctx
.
other
q_x
=
(
x
-
beta
)
/
alpha
lower_flag
=
paddle
.
cast
((
q_x
<
Qn
),
'float32'
)
upper_flag
=
paddle
.
cast
((
q_x
>
Qp
),
'float32'
)
middle_flag
=
1.0
-
lower_flag
-
upper_flag
grad_alpha
=
((
lower_flag
*
Qn
+
upper_flag
*
Qp
+
middle_flag
*
round
(
q_x
)
-
middle_flag
*
q_x
)
*
grad_x
*
g
).
sum
().
unsqueeze
(
axis
=
0
)[
0
]
grad_beta
=
((
lower_flag
+
upper_flag
)
*
grad_x
*
g
).
sum
().
unsqueeze
(
axis
=
0
)[
0
]
grad_x
=
middle_flag
*
grad_x
return
grad_x
,
grad_alpha
,
grad_beta
class
FakeQuantActLSQPlus
(
Layer
):
def
__init__
(
self
,
quant_bits
,
all_postive
=
False
,
symmetric
=
False
,
batch_init
=
20
,
dtype
=
'float32'
,
name
=
None
,
reduce_type
=
None
):
super
(
FakeQuantActLSQPlus
,
self
).
__init__
()
'''
Args:
quant_bits(int): quantization bit number for weights.
all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization.
symmetric(bool): whether symmetric or asymmetric quantization.
batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer.
dtype(str): data type.
name(str): the name of the weight.
reduce_type(str): the reduce type which is needed when parallel training.
'''
self
.
bits
=
quant_bits
self
.
all_positive
=
all_postive
self
.
symmetric
=
symmetric
self
.
batch_init
=
batch_init
self
.
name
=
name
self
.
reduce_type
=
reduce_type
if
self
.
all_positive
:
# unsigned activation
self
.
Qn
=
0
self
.
Qp
=
2
**
self
.
bits
-
1
else
:
# signed activation
self
.
Qn
=
-
2
**
(
self
.
bits
-
1
)
self
.
Qp
=
2
**
(
self
.
bits
-
1
)
-
1
scale_prefix
=
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
self
.
_scale_name
=
unique_name
.
generate
(
scale_prefix
)
s_attr
=
ParamAttr
(
name
=
self
.
_scale_name
,
initializer
=
Constant
(
1.0
),
trainable
=
True
)
self
.
s
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
s_attr
,
dtype
=
'float32'
)
self
.
s
.
stop_gradient
=
False
if
not
self
.
symmetric
:
beta_prefix
=
"{}.beta"
.
format
(
name
)
if
name
else
'quant_dequant.beta'
self
.
_beta_name
=
unique_name
.
generate
(
beta_prefix
)
beta_attr
=
ParamAttr
(
name
=
self
.
_beta_name
,
initializer
=
Constant
(
0.0
),
trainable
=
True
)
self
.
beta
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
beta_attr
,
dtype
=
'float32'
)
self
.
beta
.
stop_gradient
=
False
self
.
init_state
=
0
def
forward
(
self
,
activation
):
if
self
.
reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
self
.
s
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
if
not
self
.
symmetric
and
self
.
reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
self
.
beta
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
if
self
.
init_state
==
0
:
self
.
g
=
paddle
.
to_tensor
(
1.0
/
math
.
sqrt
(
activation
.
numel
()
*
self
.
Qp
))
min_a
=
paddle
.
min
(
activation
.
detach
())
max_a
=
paddle
.
max
(
activation
.
detach
())
self
.
s
.
set_value
((
max_a
-
min_a
)
/
(
self
.
Qp
-
self
.
Qn
))
if
not
self
.
symmetric
:
self
.
beta
.
set_value
(
min_a
-
self
.
s
*
self
.
Qn
)
self
.
init_state
+=
1
elif
self
.
init_state
<
self
.
batch_init
:
min_a
=
paddle
.
min
(
activation
.
detach
())
max_a
=
paddle
.
max
(
activation
.
detach
())
self
.
s
.
set_value
(
self
.
s
*
0.9
+
0.1
*
(
max_a
-
min_a
)
/
(
self
.
Qp
-
self
.
Qn
))
if
not
self
.
symmetric
:
self
.
beta
.
set_value
(
self
.
s
*
0.9
+
0.1
*
(
min_a
-
self
.
s
*
self
.
Qn
))
self
.
init_state
+=
1
else
:
self
.
init_state
+=
1
activation
.
stop_gradient
=
False
if
not
self
.
symmetric
:
q_a
=
LsqPlusActFunc
.
apply
(
activation
,
self
.
s
,
self
.
beta
,
self
.
g
,
self
.
Qn
,
self
.
Qp
)
else
:
q_a
=
LsqFunc
.
apply
(
activation
,
self
.
s
,
self
.
g
,
self
.
Qn
,
self
.
Qp
,
per_channel
=
False
)
return
q_a
class
FakeQuantWeightLSQPlus
(
Layer
):
def
__init__
(
self
,
quant_bits
,
all_postive
=
False
,
per_channel
=
False
,
batch_init
=
20
,
channel_num
=
None
,
quant_linear
=
False
,
dtype
=
'float32'
,
name
=
None
,
reduce_type
=
None
):
super
(
FakeQuantWeightLSQPlus
,
self
).
__init__
()
'''
Args:
quant_bits(int): quantization bit number for weights.
all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization.
per_channel(bool): whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization.
batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer.
channel_num(int): the channel number of the weight which is needed when per_channel is True.
quant_linear(bool): whether the weight is from Linear.
dtype(str): data type.
name(str): the name of the weight.
reduce_type(str): the reduce type which is needed when parallel training.
'''
self
.
bits
=
quant_bits
self
.
all_positive
=
all_postive
self
.
per_channel
=
per_channel
self
.
quant_linear
=
quant_linear
self
.
batch_init
=
batch_init
self
.
name
=
name
self
.
quant_axis
=
1
if
quant_linear
else
0
self
.
collect_axis
=
0
if
quant_linear
else
1
self
.
reduce_type
=
reduce_type
if
self
.
all_positive
:
# unsigned weight
self
.
Qn
=
0
self
.
Qp
=
2
**
self
.
bits
-
1
else
:
# signed weight
self
.
Qn
=
-
2
**
(
self
.
bits
-
1
)
self
.
Qp
=
2
**
(
self
.
bits
-
1
)
-
1
self
.
init_state
=
0
scale_prefix
=
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
self
.
_scale_name
=
unique_name
.
generate
(
scale_prefix
)
s_attr
=
ParamAttr
(
name
=
self
.
_scale_name
,
initializer
=
Constant
(
1.0
),
trainable
=
True
)
self
.
s
=
self
.
create_parameter
(
shape
=
[
channel_num
],
attr
=
s_attr
,
dtype
=
dtype
)
self
.
s
.
stop_gradient
=
False
def
forward
(
self
,
weight
):
if
self
.
reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
self
.
s
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
if
self
.
init_state
==
0
:
self
.
g
=
paddle
.
to_tensor
(
1.0
/
math
.
sqrt
(
weight
.
numel
()
*
self
.
Qp
))
self
.
div
=
2
**
self
.
bits
-
1
if
self
.
per_channel
:
weight_tmp
=
weight
.
detach
().
reshape
((
weight
.
shape
[
0
],
-
1
))
mean
=
paddle
.
mean
(
weight_tmp
,
axis
=
self
.
collect_axis
)
std
=
paddle
.
std
(
weight_tmp
,
axis
=
self
.
collect_axis
)
s
=
paddle
.
max
(
paddle
.
stack
(
[
paddle
.
abs
(
mean
-
3
*
std
),
paddle
.
abs
(
mean
+
3
*
std
)]),
axis
=
0
)
self
.
s
.
set_value
(
s
/
self
.
div
)
else
:
mean
=
paddle
.
mean
(
weight
.
detach
())
std
=
paddle
.
std
(
weight
.
detach
())
self
.
s
.
set_value
(
max
([
paddle
.
abs
(
mean
-
3
*
std
),
paddle
.
abs
(
mean
+
3
*
std
)
])
/
self
.
div
)
self
.
init_state
+=
1
elif
self
.
init_state
<
self
.
batch_init
:
self
.
div
=
2
**
self
.
bits
-
1
if
self
.
per_channel
:
weight_tmp
=
weight
.
detach
().
reshape
((
weight
.
shape
[
0
],
-
1
))
mean
=
paddle
.
mean
(
weight_tmp
,
axis
=
self
.
collect_axis
)
std
=
paddle
.
std
(
weight_tmp
,
axis
=
self
.
collect_axis
)
s
=
paddle
.
max
(
paddle
.
stack
(
[
paddle
.
abs
(
mean
-
3
*
std
),
paddle
.
abs
(
mean
+
3
*
std
)]),
axis
=
0
)
self
.
s
.
set_value
(
s
*
0.9
+
0.1
*
s
/
self
.
div
)
else
:
mean
=
paddle
.
mean
(
weight
.
detach
())
std
=
paddle
.
std
(
weight
.
detach
())
self
.
s
.
set_value
(
self
.
s
*
0.9
+
0.1
*
max
(
[
paddle
.
abs
(
mean
-
3
*
std
),
paddle
.
abs
(
mean
+
3
*
std
)])
/
self
.
div
)
self
.
init_state
+=
1
elif
self
.
init_state
==
self
.
batch_init
:
self
.
init_state
+=
1
weight
.
stop_gradient
=
False
w_q
=
LsqFunc
.
apply
(
weight
,
self
.
s
,
self
.
g
,
self
.
Qn
,
self
.
Qp
,
self
.
per_channel
,
self
.
quant_axis
)
return
w_q
python/paddle/nn/quant/quant_layers.py
浏览文件 @
341f68fe
...
...
@@ -26,6 +26,7 @@ from paddle.fluid.log_helper import get_logger
from
paddle
import
_C_ops
,
_legacy_C_ops
from
paddle
import
in_dynamic_mode
from
paddle.nn
import
Layer
from
paddle.nn.quant.lsq
import
FakeQuantActLSQPlus
,
FakeQuantWeightLSQPlus
__all__
=
[
'FakeQuantAbsMax'
,
...
...
@@ -653,7 +654,8 @@ class QuantizedLinear(Layer):
dtype
=
self
.
_dtype
,
quant_on_weight
=
True
,
channel_num
=
self
.
weight
.
shape
[
self
.
_linear_quant_axis
],
quant_axis
=
self
.
_linear_quant_axis
)
quant_axis
=
self
.
_linear_quant_axis
,
quant_linear
=
True
)
if
act_quant_layer
is
not
None
:
self
.
_fake_quant_input
=
act_quant_layer
()
...
...
@@ -946,10 +948,29 @@ def _get_fake_quant_type(quant_type, **kwargs):
assert
call_args
[
"channel_num"
]
is
not
None
,
(
"You need to input channel_num"
"when you use channel_wise_abs_max strategy."
)
elif
quant_type
==
'lsq_weight'
:
call_args
[
"all_postive"
]
=
kwargs
.
get
(
"all_postive"
,
False
)
call_args
[
"per_channel"
]
=
False
call_args
[
"channel_num"
]
=
1
call_args
[
"quant_linear"
]
=
kwargs
.
get
(
"quant_linear"
,
False
)
elif
quant_type
==
'channel_wise_lsq_weight'
:
quant_type
=
'lsq_weight'
call_args
[
"all_postive"
]
=
kwargs
.
get
(
"all_postive"
,
False
)
call_args
[
"per_channel"
]
=
True
call_args
[
"channel_num"
]
=
kwargs
.
get
(
"channel_num"
,
None
)
call_args
[
"quant_linear"
]
=
kwargs
.
get
(
"quant_linear"
,
False
)
assert
call_args
[
"channel_num"
]
is
not
None
,
(
"You need to input channel_num"
"when you use channel_wise_abs_max strategy."
)
elif
quant_type
==
'lsq_act'
:
call_args
[
"all_postive"
]
=
kwargs
.
get
(
"all_postive"
,
False
)
call_args
[
"symmetric"
]
=
kwargs
.
get
(
"symmetric"
,
True
)
fake_quant_map
=
{
'abs_max'
:
FakeQuantAbsMax
,
'moving_average_abs_max'
:
FakeQuantMovingAverageAbsMax
,
'channel_wise_abs_max'
:
FakeQuantChannelWiseAbsMax
'channel_wise_abs_max'
:
FakeQuantChannelWiseAbsMax
,
'lsq_weight'
:
FakeQuantWeightLSQPlus
,
'lsq_act'
:
FakeQuantActLSQPlus
}
return
fake_quant_map
[
quant_type
](
**
call_args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录