PaddlePaddle / PaddleSlim, commit bc442429 (unverified)
Authored by gushiqiao, committed via GitHub on Oct 14, 2022
Add reconstruction quant algorithm (#1457)

Parent commit: 880ad20b
6 changed files with 1,196 additions and 57 deletions (+1196, -57)
Changed files:
- example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_fine_tune.yaml (+32, -0)
- example/post_training_quantization/pytorch_yolo_series/fine_tune.py (+121, -0)
- paddleslim/quant/__init__.py (+1, -0)
- paddleslim/quant/quanter.py (+1, -1)
- paddleslim/quant/reconstruction_quantization.py (+975, -0)
- tests/test_reconstruct_quantization.py (+66, -56)
example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_fine_tune.yaml (new file, mode 100755)

```yaml
arch: YOLOv6
model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/
model_filename: model.pdmodel
params_filename: model.pdiparams
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
skip_tensor_list: None
regions: [['x2paddle_image_arrays', 'relu_8.tmp_0'],
          ['relu_8.tmp_0', 'relu_15.tmp_0'],
          ['relu_15.tmp_0', 'relu_21.tmp_0'],
          ['concat_1.tmp_0', 'relu_26.tmp_0'],
          ['concat_2.tmp_0', 'relu_30.tmp_0'],
          ['relu_30.tmp_0', 'concat_4.tmp_0'],
          ['relu_30.tmp_0', 'relu_31.tmp_0'],
          ['concat_3.tmp_0', 'relu_35.tmp_0'],
          ['relu_35.tmp_0', 'relu_36.tmp_0'],
          ['concat_5.tmp_0', 'concat_10.tmp_0'],
          ['relu_35.tmp_0', 'concat_8.tmp_0']]
region_weights_names: [['conv2d_0.w_0', 'conv2d_1.w_0', 'conv2d_2.w_0', 'conv2d_3.w_0', 'conv2d_4.w_0', 'conv2d_5.w_0', 'conv2d_6.w_0', 'conv2d_7.w_0', 'conv2d_8.w_0'],
                       ['conv2d_9.w_0', 'conv2d_10.w_0', 'conv2d_11.w_0', 'conv2d_12.w_0', 'conv2d_13.w_0', 'conv2d_14.w_0', 'conv2d_15.w_0'],
                       ['conv2d_16.w_0', 'conv2d_17.w_0', 'conv2d_18.w_0', 'conv2d_19.w_0', 'conv2d_20.w_0', 'conv2d_21.w_0'],
                       ['conv2d_22.w_0', 'conv2d_23.w_0', 'conv2d_24.w_0', 'conv2d_25.w_0', 'conv2d_26.w_0'],
                       ['conv2d_27.w_0', 'conv2d_28.w_0', 'conv2d_29.w_0', 'conv2d_30.w_0'],
                       ['conv2d_32.w_0', 'conv2d_34.w_0', 'conv2d_35.w_0', 'conv2d_37.w_0', 'conv2d_38.w_0', 'conv2d_39.w_0'],
                       ['conv2d_31.w_0'],
                       ['conv2d_33.w_0', 'conv2d_36.w_0', 'conv2d_40.w_0', 'conv2d_41.w_0'],
                       ['conv2d_42.w_0'],
                       ['conv2d_44.w_0', 'conv2d_47.w_0', 'conv2d_51.w_0', 'conv2d_52.w_0', 'conv2d_53.w_0', 'conv2d_54.w_0', 'conv2d_55.w_0', 'conv2d_56.w_0', 'conv2d_57.w_0', 'conv2d_58.w_0'],
                       ['conv2d_43.w_0', 'conv2d_45.w_0', 'conv2d_46.w_0', 'conv2d_49.w_0', 'conv2d_48.w_0', 'conv2d_50.w_0'],]
```
(no newline at end of file)
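Since `regions` and `region_weights_names` must stay aligned one-to-one, a quick sanity check can catch config drift. Here is a minimal sketch (mine, not part of the commit), assuming the file is parsed with plain PyYAML rather than PaddleSlim's `load_config`:

```python
import yaml  # assumes PyYAML is installed

with open('yolov6s_fine_tune.yaml') as f:
    cfg = yaml.safe_load(f)

regions = cfg['regions']                    # [input_var, output_var] pairs
weight_names = cfg['region_weights_names']  # conv weights inside each region

assert len(regions) == len(weight_names), \
    'every region needs exactly one list of weight names'
for (start, end), names in zip(regions, weight_names):
    print(f'{start} -> {end}: {len(names)} weights')
```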
example/post_training_quantization/pytorch_yolo_series/fine_tune.py (new file, mode 100755)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant import quant_post_static
from paddleslim.quant import quant_recon_static
from dataset import COCOTrainDataset


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of post training quantization config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='ptq_out',
        help="directory to save compressed model.")
    parser.add_argument(
        '--devices',
        type=str,
        default='gpu',
        help="which device used to compress.")
    parser.add_argument(
        '--algo', type=str, default='avg', help="post quant algo.")
    parser.add_argument(
        '--round_type', type=str, default='adaround', help="round type.")
    parser.add_argument('--gpu', type=int, default=0, help='gpu index')
    parser.add_argument(
        '--recon_level',
        type=str,
        default='layer-wise',
        help='reconstruction level')
    parser.add_argument(
        '--simulate_activation_quant',
        type=bool,
        default=False,
        help='simulate activation quant')
    return parser


def main():
    global config
    config = load_config(FLAGS.config_path)

    input_name = 'x2paddle_image_arrays' if config[
        'arch'] == 'YOLOv6' else 'x2paddle_images'
    dataset = COCOTrainDataset(
        dataset_dir=config['dataset_dir'],
        image_dir=config['val_image_dir'],
        anno_path=config['val_anno_path'],
        input_name=input_name)
    train_loader = paddle.io.DataLoader(
        dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)

    place = paddle.CUDAPlace(
        FLAGS.gpu) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    # Since the type of model converted from PyTorch is ONNX,
    # use load_onnx_model first and rename the model_dir.
    load_onnx_model(config["model_dir"])
    inference_model_path = config["model_dir"].rstrip().rstrip(
        '.onnx') + '_infer'

    quant_recon_static(
        executor=exe,
        model_dir=inference_model_path,
        quantize_model_path=FLAGS.save_dir,
        data_loader=train_loader,
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        batch_size=32,
        batch_nums=10,
        algo=FLAGS.algo,
        hist_percent=0.999,
        is_full_quantize=False,
        bias_correction=False,
        onnx_format=False,
        weight_quantize_type='channel_wise_abs_max',
        recon_level=FLAGS.recon_level,
        simulate_activation_quant=FLAGS.simulate_activation_quant,
        regions=config['regions'],
        region_weights_names=config['region_weights_names'],
        skip_tensor_list=config['skip_tensor_list']
        if 'skip_tensor_list' in config else None,
        epochs=20,
        lr=0.1)


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()

    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
    paddle.set_device(FLAGS.devices)

    main()
```
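Taken together with the YAML above, the script would presumably be launched as `python fine_tune.py --config_path=./configs/yolov6s_fine_tune.yaml --recon_level=region-wise`; the exact flags follow the argparser above, with `--recon_level` defaulting to `layer-wise` and `--algo` to `avg`.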
paddleslim/quant/__init__.py (mode 100644 → 100755)

```diff
@@ -26,6 +26,7 @@ try:
     from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
     from .quanter import quant_post, quant_post_only_weight
     from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
+    from .reconstruction_quantization import quant_recon_static
     if platform.system().lower() == 'linux':
         from .post_quant_hpo import quant_post_hpo
 else:
```
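The practical effect of this one-line change is that the new entry point becomes importable from the package namespace, i.e. `from paddleslim.quant import quant_recon_static`.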
paddleslim/quant/quanter.py

```diff
@@ -813,4 +813,4 @@ def pact(x, name=None):


 def get_pact_optimizer():
-    return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
+    return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
\ No newline at end of file
```
paddleslim/quant/rounding_optimizer.py → paddleslim/quant/reconstruction_quantization.py (renamed, mode 100644 → 100755)
```diff
-import numpy as np
-import time
-import sys
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
 import logging
+import math
+import os
+import re
+import shutil
+import sys
+import time
+
+import numpy as np
 import paddle
 import paddle.fluid as fluid
-import six
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
-import math
-import copy
+from paddle.fluid.contrib.slim.quantization import utils
 from ..dist import merge
 from ..core.graph_wrapper import GraphWrapper
 from ..common import get_logger
-from paddle.fluid.contrib.slim.quantization import utils

-_logger = get_logger(__name__,
-                     logging.INFO,
-                     fmt='%(asctime)s-%(levelname)s: %(message)s')
+__all__ = ['ReconstructionQuantization', ]
+
+_logger = get_logger(
+    __name__,
+    logging.INFO,
+    fmt='%(asctime)s-%(levelname)s: %(message)s',
+)

 GAMMA = -0.1
 ZETA = 1.1
-
-__all__ = ['RoundingOptimizer', ]
```
The old `RoundingOptimizerLoss` class (arguments `weight_block_names`, `round_loss_mode`, `rec_loss_mode`, `beta_mode`) is deleted from the top of the module; it reappears near the end of the file, renamed `ReconstructionQuanterLoss` with matching argument renames (reconstructed below). In its place the diff adds a small `Collections` config holder and the `ReconstructionQuantization` driver:

```python
class Collections(object):
    def __init__(self, **kwargs):
        self._config = dict()
        for k, v in kwargs.items():
            self._config[k] = v

    def _get_config(self):
        return self._config


class ReconstructionQuantization(PostTrainingQuantization):
    """
    Utilizing reconstruction quantization method to quantize the FP32 model,
    and it uses calibrate data to get the quantization information for all
    quantized variables.
    """

    def __init__(self, PTQCollections, RSQCollections):
        '''
        Args:
            PTQCollections(Collections): The parameters set required for post training quantization.
            RSQCollections(Collections): The parameters set required for reconstruction quantization.
        Returns:
            None
        '''
        super().__init__(**PTQCollections._get_config())
        self._config = RSQCollections._get_config()

    def quantize(self):
        '''
        Load the FP32 model, and use the calibrate data to calculate the forward-stage.
        Based on the sample data, we can get the quantization information, and obtain
        the final quantized model.
        Args:
            None
        Returns:
            the program of quantized model.
        '''
        self._load_model_data()
        self._collect_target_varnames()
        self._set_activation_persistable()
        if self._algo in ["KL", "hist"]:
            self._preparation()
        self._sampling_threshold()
        self._calculate_threshold()
        self._reset_activation_persistable()
        self._reconstruction()
        self._postprocessing()
        return self._program

    def _preparation(self):
        batch_id = 0
        with utils.tqdm(
                total=self._batch_nums,
                bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
                ncols=80, ) as t:
            for data in self._data_loader():
                self._executor.run(
                    program=self._program,
                    feed=data,
                    fetch_list=self._fetch_list,
                    return_numpy=False,
                    scope=self._scope, )
                self._collect_activation_abs_min_max()
                batch_id += 1
                t.update()
                if self._batch_nums and batch_id >= self._batch_nums:
                    break
        self._init_sampling_act_histogram()

    def _sampling_threshold(self):
        batch_id = 0
        with utils.tqdm(
                total=self._batch_nums,
                bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
                ncols=80, ) as t:
            for data in self._data_loader():
                self._executor.run(
                    program=self._program,
                    feed=data,
                    fetch_list=self._fetch_list,
                    return_numpy=False,
                    scope=self._scope, )
                self._sampling()
                batch_id += 1
                t.update()
                if self._batch_nums and batch_id >= self._batch_nums:
                    break

    def _calculate_threshold(self):
        if self._algo == 'avg':
            for var_name in self._quantized_act_var_name:
                self._quantized_threshold[var_name] = \
                    np.array(self._quantized_var_avg[var_name]).mean()
            self._scale_dict = self._quantized_threshold
        elif self._algo in ["KL", "hist"]:
            self._calculate_kl_hist_threshold()
            self._scale_dict = self._quantized_var_threshold
        else:
            self._scale_dict = self._quantized_threshold

    def _reconstruction(self):
        reconstruction_quanter = ReconstructionQuanter(
            data_loader=self._data_loader,
            fp32_program=self._program,
            feed_list=self._feed_list,
            fetch_list=self._fetch_list,
            exe=self._executor,
            scope=self._scope,
            place=self._place,
            quantized_op_pairs=self._quantized_op_pairs,
            weight_quantize_type=self._weight_quantize_type,
            scale_dict=copy.deepcopy(self._scale_dict),
            regions=self._config['regions'],
            region_weights_names=self._config['region_weights_names'],
            recon_level=self._config['recon_level'],
            simulate_activation_quant=self._config['simulate_activation_quant'],
            num_iterations=self._batch_nums,
            lr=self._config['lr'],
            bias_correction=self._bias_correction,
            epochs=self._config['epochs'],
            scale_trainable=self._config['scale_trainable'])
        self._program = reconstruction_quanter._run()

    def _postprocessing(self):
        if self._algo is 'min_max':
            self._save_input_threhold()
        else:
            self._update_program()

        # save out_threshold for quantized ops.
        self._save_output_threshold()

        if any(op_type in self._quantizable_op_type
               for op_type in self._dynamic_quantize_op_type):
            self._collect_dynamic_quantize_op_threshold(
                self._dynamic_quantize_op_type, )

        # Move sub blocks persistable var to global block
        global_block = self._program.global_block()
        for _op in global_block.ops:
            if _op.type == "while":
                _block_id = _op.attr("sub_block").id
                _block = self._program.block(_block_id)
                persistables = []
                for _name, _var in _block.vars.items():
                    if _var.persistable:
                        global_block._clone_variable(_var)
                        persistables.append(_name)
                for _name in persistables:
                    _block._remove_var(_name)
                persistables.extend(_op.input('X'))
                _op.desc.set_input("X", persistables)
```
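For orientation, the two `Collections` bundles are what `quant_recon_static` (reconstructed at the end of this diff) builds before driving the class above. In miniature, with placeholder values (`exe` and `loader` are assumed to exist already, so this is a hypothetical sketch, not code from the commit):

```python
# Hypothetical wiring, mirroring quant_recon_static below.
ptq = Collections(executor=exe, model_dir='./yolov6s_infer',
                  data_loader=loader, algo='avg', round_type='adaround')
rsq = Collections(recon_level='layer-wise', simulate_activation_quant=False,
                  regions=None, region_weights_names=None, epochs=20,
                  scale_trainable=False, lr=0.1)
program = ReconstructionQuantization(
    PTQCollections=ptq, RSQCollections=rsq).quantize()
```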
```diff
-class RoundingOptimizer(object):
+class ReconstructionQuanter(object):
     def __init__(self,
                  data_loader,
                  fp32_program,
...
@@ -90,16 +212,18 @@ class RoundingOptimizer(object):
                  quantized_op_pairs,
                  weight_quantize_type,
                  scale_dict,
-                 blocks,
-                 block_weights_names,
-                 round_type,
+                 regions,
+                 region_weights_names,
+                 recon_level,
+                 simulate_activation_quant,
                  num_iterations=1000,
                  lr=0.1,
                  bias_correction=False,
                  epochs=20,
-                 ):
+                 scale_trainable=False,
+                 drop_prob=0.5):
         '''
-        Rounding Optimizer, used to optimize the rounding policy
+        Reconstruction Quanter, used to optimize the rounding policy
         by reconstructing the intermediate output.
         Args:
...
@@ -108,44 +232,51 @@ class RoundingOptimizer(object):
                 return a batch every time.
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
             scope(fluid.Scope, optional): The scope of the program, use it to load
                 and save variables. If scope=None, get scope by global_scope().
             place(CPUPlace()|CUDAPlace(N)): This parameter represents
                 paddle run on which device.
             quantized_op_pairs(dict, optional): Mapping of op's weight name
                 and output var name, where key of dict is the weight name of
                 op, and value is the output var name of op.
             weight_quantize_type(str): quantization type for weights,
                 support 'abs_max' and 'channel_wise_abs_max'. This param only specifies
                 the fake ops in saving quantized model, and we save the scale obtained
                 by post training quantization in fake ops. Compared to 'abs_max',
                 the model accuracy is usually higher when it is 'channel_wise_abs_max'.
             scale_dict(dict, optional): Mapping of var's name and var's scales, where key
                 of dict is the var name, and value is the quant scales of var.
-            round_type(str, optional): The rounding policy of converting the quantized
-                weights value float->int. Currently supports ['round', 'brecq', 'qdrop']
-                methods.
-                'adaround' is refer to https://arxiv.org/abs/2004.10568,
-                'brecq' is refer to https://arxiv.org/pdf/2102.05426,
-                'qdrop' is refer to https://arxiv.org/pdf/2203.05740.
-            blocks(list[list], optional): The list of some blocks, each block is subgraph of
-                fp32 program and it will have exact 1 input operation and 1 output operation.
-            block_weights_names(list[list], optional): The weight names inside every block.
-            lr(float, optional): The learning rate of Rounding Optimizer.
+            recon_level(str, optional): The type of reconstruction granularity.
+                Currently support ['layer-wise', 'region-wise'] types. Default is layer-wise.
+            simulate_activation_quant(bool, optional): Whether we need the noise caused by activation
+                quantization during the reconstruction process.
+            regions(list[list], optional): The list of some regions, each region is a subgraph of
+                fp32 program and it will have exact 1 input operation and 1 output operation. When
+                the recon-level is region, the reconstruction loss of each region is minimized.
+                Default is None.
+            region_weights_names(list[list], optional): The weight names inside every region.
+                Default is None.
+            lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1.
             bias_correction(bool, optional): If set as True, use the bias correction
                 method of https://arxiv.org/abs/1810.05723. Default is False.
+            scale_trainable: Whether weight's scale is trainable. Default is False.
+            drop_prob: The dropout probability of activation quantization, and it is valid only if
+                simulate_activation_quant is True. Default is 0.5.
         Returns:
             None
         '''
-        assert round_type in ['adaround', 'brecq', 'qdrop']
-        if round_type in ['brecq', 'qdrop']:
-            assert blocks is not None, "The blocks cannot be None."
-            assert block_weights_names is not None, "The block_weights_names cannot be None."
+        assert recon_level in [
+            'layer-wise', 'region-wise'
+        ], "recon_level must be one of the ['layer-wise', 'region-wise'],but received: {}".format(
+            recon_level)
+        if recon_level == 'region-wise':
+            assert regions is not None, "The regions cannot be None."
+            assert region_weights_names is not None, "The region_weights_names cannot be None."
+        self._simulate_activation_quant = simulate_activation_quant
         self._program = fp32_program
         self._data_loader = data_loader
-        self._round_type = round_type
+        self._recon_level = recon_level
         self._feed_list = feed_list
         self._fetch_list = fetch_list
         self._exe = exe
...
@@ -158,17 +289,19 @@ class RoundingOptimizer(object):
         self._num_iterations = num_iterations
         self._epochs = epochs
         self._lr = lr
-        self._blocks = blocks
-        self._block_weights_names = block_weights_names
+        self._regions = regions
+        self._region_weights_names = region_weights_names
         self._bias_correction = bias_correction
-        if round_type in ['adaround']:
-            blocks, block_weights_names = self._get_layers()
-            self._blocks = blocks
-            self._block_weights_names = block_weights_names
+        if self._recon_level == 'layer-wise':
+            regions, region_weights_names = self._get_layers()
+            self._regions = regions
+            self._region_weights_names = region_weights_names
+        self._scale_trainable = scale_trainable
+        self._drop_prob = drop_prob
```
```diff
     def _get_layers(self):
-        blocks = []
-        block_weights_names = []
+        regions = []
+        region_weights_names = []
         persistable_var_names = self._all_persistable_var_names()
         self._input_weight_pairs = {}
         for block_id in range(len(self._program.blocks)):
...
@@ -180,14 +313,14 @@ class RoundingOptimizer(object):
                     self._input_weight_pairs[in_var_name] = in_var_names
                     break
         for name in self._weight_var_names:
-            block_weights_names.append([name])
-            block_ = []
-            block_.append(self._input_weight_pairs[name][0])
-            block_.append(self._quantized_op_pairs[name])
-            blocks.append(block_)
-        return blocks, block_weights_names
+            region_weights_names.append([name])
+            region_ = []
+            region_.append(self._input_weight_pairs[name][0])
+            region_.append(self._quantized_op_pairs[name])
+            regions.append(region_)
+        return regions, region_weights_names

     def _preprocess(self):
         data_name_map = {}
         for name in self._feed_list:
             data_name_map[name] = name
```
```diff
@@ -199,35 +332,51 @@ class RoundingOptimizer(object):
             self._place,
             teacher_scope=None,
             name_prefix="teacher_",
-            merge_feed=True)
+            merge_feed=True, )
         for name in self._weight_var_names:
             weight_np = utils.load_variable_data(self._scope, name)
             scale = self._scale_dict[name]
             weight_np_floor = np.floor(utils.quant_tensor(weight_np, scale))
-            utils.set_variable_data(self._scope, self._place, name, weight_np_floor)
+            utils.set_variable_data(self._scope, self._place, name,
+                                    weight_np_floor, )
         self._graph = GraphWrapper(self._student_program)
-        if self._round_type == 'qdrop':
+        if self._simulate_activation_quant:
             self._insert_drop_quant_dequant()
         self._insert_soft_rounding()
-        self._isolate_blocks()
+        self._isolate_regions()

     def _run(self):
         self._preprocess()
         startup_program = paddle.static.Program()
-        for k in range(len(self._blocks)):
-            block_ = self._blocks[k]
-            names = self._block_weights_names[k]
+        for k in range(len(self._regions)):
+            region_ = self._regions[k]
+            names = self._region_weights_names[k]
             tmp_program = self._student_program.clone()
-            quant_op_out_name = block_[1]
+            quant_op_out_name = region_[1]
             with paddle.static.program_guard(tmp_program, startup_program):
-                loss_function = RoundingOptimizerLoss(tmp_program, names)
-                quant_op_out_name = block_[1]
+                loss_function = ReconstructionQuanterLoss(tmp_program, names)
+                quant_op_out_name = region_[1]
                 student_var = tmp_program.global_block().var(quant_op_out_name)
-                teacher_var = tmp_program.global_block().var("teacher_" + quant_op_out_name)
-                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
-                    learning_rate=20, eta_min=2, T_max=2000, verbose=True)
+                teacher_var = tmp_program.global_block().var("teacher_" +
+                                                             quant_op_out_name)
+                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+                    learning_rate=20, eta_min=2, T_max=2000, verbose=True, )
                 total_loss, recon_loss, round_loss = loss_function.get_loss(
-                    student_var, teacher_var, scheduler)
+                    student_var,
+                    teacher_var,
+                    scheduler, )
                 train_fetches_loss = {
                     "total_loss": total_loss,
                     "recon_loss": recon_loss,
-                    "round_loss": round_loss
+                    "round_loss": round_loss,
                 }
                 optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
                 optimizer.minimize(total_loss)
```
```diff
@@ -241,11 +390,17 @@ class RoundingOptimizer(object):
                 out = self._exe.run(
                     tmp_program,
                     feed=data,
-                    fetch_list=[v.name for v in train_fetches_loss.values()],
-                    return_numpy=True)
+                    fetch_list=[
+                        v.name for v in train_fetches_loss.values()
+                    ],
+                    return_numpy=True, )
                 _logger.info(
                     "Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
-                    .format(epoch, self._lr, np.mean(out[0]), np.mean(out[1]), np.mean(out[2]), start_time - prev_start_time))
+                    .format(epoch, self._lr,
+                            np.mean(out[0]),
+                            np.mean(out[1]),
+                            np.mean(out[2]), start_time - prev_start_time),
+                )
                 sys.stdout.flush()
                 if i == self._num_iterations:
                     break
```
```diff
@@ -255,7 +410,7 @@ class RoundingOptimizer(object):
         return self._program

     def _init_alpha(self, name, scale):
         _tensor = utils.load_variable_data(self._scope, "teacher_" + name)
         tensor_scaled = utils.quant_tensor(_tensor, scale)
         tensor_floor = np.floor(tensor_scaled)
         tensor = tensor_scaled - tensor_floor
```
```diff
@@ -269,31 +424,39 @@ class RoundingOptimizer(object):
             weight: The quanted weight with dtype=float32
         """
         bnt = (1 << (weight_bits - 1)) - 1

         def _dequant(x, scale):
             s = (scale + 1e-8) / bnt
             dequant_x = s * x
             return dequant_x

-        quantized_weight = paddle.static.data(shape=weight.shape,
-                                              dtype=weight.dtype,
-                                              name=weight.name + '_quant')
-        v = paddle.static.create_parameter(shape=weight.shape,
-                                           dtype=weight.dtype,
-                                           name=weight.name + ".alpha",
-                                           default_initializer=fluid.initializer.NumpyArrayInitializer(self._alpha))
+        quantized_weight = paddle.static.data(
+            shape=weight.shape,
+            dtype=weight.dtype,
+            name=weight.name + '_quant', )
+        v = paddle.static.create_parameter(
+            shape=weight.shape,
+            dtype=weight.dtype,
+            name=weight.name + ".alpha",
+            default_initializer=fluid.initializer.NumpyArrayInitializer(
+                self._alpha, ),
+        )
         h_v = paddle.clip(
-            paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA, 0, 1)
+            paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA,
+            0,
+            1, )
         if self._weight_quantize_type == 'channel_wise_abs_max':
             scale_var = paddle.static.create_parameter(
                 dtype=weight.dtype,
                 shape=weight.shape,
                 name=weight.name + '.scale',
-                default_initializer=fluid.initializer.NumpyArrayInitializer(scale),
-            )
+                default_initializer=fluid.initializer.NumpyArrayInitializer(
+                    scale, ),
+            )
         else:
             scale_var = scale
         w = _dequant(quantized_weight + h_v, scale_var)
         return w

     def _insert_soft_rounding(self):
```
```diff
@@ -302,26 +465,28 @@ class RoundingOptimizer(object):
             scale = self._scale_dict[name]
             shape = weight.shape()
             self._alpha = self._init_alpha(name, scale)
             if self._weight_quantize_type == 'channel_wise_abs_max':
                 scale = np.array(scale)
                 scale = scale.reshape(scale.shape[0], 1)
                 if len(shape) == 2:
                     scale = scale.repeat(shape[0], axis=0)
                 else:
                     scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
                 scale = scale.reshape(shape)
             self._insert_func(var=weight, scale=scale, func="_soft_rounding")

     def _drop_quant_dequant(self, inputs, scale, weight_bits=8):
-        x = paddle.static.data(shape=inputs.shape,
-                               dtype=inputs.dtype,
-                               name=inputs.name + '.tmp')
+        x = paddle.static.data(
+            shape=inputs.shape,
+            dtype=inputs.dtype,
+            name=inputs.name + '.tmp', )
         bnt = (1 << (weight_bits - 1)) - 1
         scale = scale / bnt
         dequantized_tensor = paddle.round(x / scale) * scale
         quant_noise = x - dequantized_tensor
-        random_noise = paddle.nn.functional.dropout(quant_noise, p=0.5)
-        return x + random_noise
+        random_noise = paddle.nn.functional.dropout(
+            quant_noise, p=self._drop_prob)
+        return x - random_noise
```
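For intuition, here is a NumPy sketch (mine, not from the commit) of what `_drop_quant_dequant` computes per activation tensor: where the dropout keeps the quantization noise, the element behaves as if quantized; where it zeroes the noise, the original value passes through unchanged:

```python
import numpy as np

def drop_quant_dequant(x, scale, drop_prob=0.5, bits=8):
    bnt = (1 << (bits - 1)) - 1                  # 127 for 8-bit
    s = scale / bnt
    dequantized = np.round(x / s) * s            # fake quant-dequant
    quant_noise = x - dequantized                # element-wise rounding error
    dropped = np.random.rand(*x.shape) < drop_prob
    # paddle.nn.functional.dropout also rescales kept elements by 1/(1-p)
    # in its default mode; that scaling is omitted here for clarity.
    random_noise = np.where(dropped, 0.0, quant_noise)
    return x - random_noise                      # quantized where noise kept
```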
```diff
     def _insert_drop_quant_dequant(self):
         for op in self._graph.ops():
...
@@ -337,7 +502,10 @@ class RoundingOptimizer(object):
             else:
                 input = op.inputs("X")[0]
             if input.name() in self._scale_dict.keys():
-                self._insert_func(var=input, scale=self._scale_dict[input.name()], func="_drop_quant_dequant")
+                self._insert_func(
+                    var=input,
+                    scale=self._scale_dict[input.name()],
+                    func="_drop_quant_dequant", )

     def _insert_func(self, var, scale, func):
         program = var._graph.program
...
@@ -346,51 +514,51 @@ class RoundingOptimizer(object):
         startup_program = paddle.static.Program()
         new_program = paddle.static.Program()
         with paddle.static.program_guard(new_program, startup_program):
             if func == "_soft_rounding":
                 out = self._soft_rounding(inputs, scale)
             elif func == "_drop_quant_dequant":
                 out = self._drop_quant_dequant(inputs, scale)
         self._exe.run(startup_program)
-        #create var in program
+        # create var in program
         for new_var in new_program.list_vars():
             if new_var.name == var._var.name + '_quant' or new_var.name == var._var.name + '.tmp':
                 continue
             elif new_var.name == var._var.name + '.alpha':
                 program.global_block().create_parameter(
                     name=new_var.name,
                     shape=new_var.shape,
                     dtype=new_var.dtype,
                     type=new_var.type,
                     stop_gradient=new_var.stop_gradient)
             elif new_var.name == var._var.name + '.scale':
                 program.global_block().create_parameter(
                     name=new_var.name,
                     shape=new_var.shape,
                     dtype=new_var.dtype,
                     type=new_var.type,
                     stop_gradient=True,
-                    trainable=False)
+                    trainable=self._scale_trainable, )
             else:
                 if func == "_soft_rounding":
                     program.global_block().create_var(
                         name=new_var.name + '.rounding',
                         shape=new_var.shape,
                         dtype=new_var.dtype,
                         type=new_var.type,
                         persistable=new_var.persistable,
                         stop_gradient=new_var.stop_gradient, )
                 else:
                     program.global_block().create_var(
                         name=new_var.name,
                         shape=new_var.shape,
                         dtype=new_var.dtype,
                         type=new_var.type,
                         persistable=new_var.persistable,
                         stop_gradient=new_var.stop_gradient, )
         op_list = new_program.global_block().ops
         op_list = list(reversed(op_list))
         block = var._var.block
-        #prepend new_program's op in program
+        # prepend new_program's op in program
         for _op in ops:
             if _op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
                 continue
```
```diff
@@ -398,84 +566,96 @@ class RoundingOptimizer(object):
         for op in op_list:
             # _attrs = op.all_attrs()
             _type = op.type
             _attrs = {
                 'use_mkldnn': False,
                 'with_quant_attr': False,
             }
             if _type == 'clip':
                 _attrs = {
                     'use_mkldnn': False,
                     'with_quant_attr': False,
                     'max': op.attr('max'),
                     'min': op.attr('min'),
                 }
             elif _type == 'scale':
                 _attrs = {
                     'use_mkldnn': False,
                     'with_quant_attr': False,
                     'scale': op.attr('scale'),
                     'bias_after_scale': op.attr('bias_after_scale'),
                 }
             elif _type == 'elementwise_mul':
                 _attrs = {
                     'use_mkldnn': False,
                     'with_quant_attr': False,
                     'Scale_out': op.attr('Scale_out'),
                     'Scale_x': op.attr('Scale_x'),
                     'Scale_y': op.attr('Scale_y'),
                     'axis': op.attr('axis'),
                 }

             if func == "_soft_rounding":
                 _outputs = {'Out': op.output('Out')[0] + '.rounding'}
                 if _type == "elementwise_add":
                     _inputs = {
                         'X': var._var,  # replace tmp var conv.weight_quant with var conv.weight
                         'Y': op.input('Y')[0] + '.rounding',
                     }
                 elif _type == "elementwise_mul":
                     _inputs = {
                         'X': op.input('X')[0] + '.rounding',
                         'Y': op.input('Y')[0] + '.rounding',
                     }
                 elif (_type == 'scale' and
                       op.input('X')[0].endswith('scale')) or _type == 'sigmoid':
                     _inputs = {'X': op.input('X')[0]}
                 else:
                     _inputs = {'X': op.input('X')[0] + '.rounding'}
             elif func == "_drop_quant_dequant":
                 if _type == 'dropout':
                     _outputs = {
                         'Out': op.output('Out')[0],
                         'Mask': op.output('Mask')[0],
                     }
                 else:
                     _outputs = {'Out': op.output('Out')[0]}

                 if _type == 'elementwise_add' or _type == 'elementwise_sub':
                     _inputs = {
                         'X': var._var,  # replace tmp var conv.weight_quant with var conv.weight
                         'Y': op.input('Y'),
                     }
                 elif _type == 'scale' and op.input('X')[0] == inputs.name + '.tmp':
                     _inputs = {'X': var._var}
                 else:
                     _inputs = {'X': op.input('X')[0]}

             block._insert_op(
                 idx,
                 type=_type,
                 attrs=_attrs,
                 inputs=_inputs,
                 outputs=_outputs, )

         for op in ops:
             if op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
                 continue
             if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs(
                     'Filter')[0].name().startswith('teacher'):
                 continue
             if op.type() in ['mul'] and op.inputs('Y')[0].name().startswith(
                     'teacher'):
                 continue
             if func == '_soft_rounding':
                 op._op._rename_input(inputs.name, out.name + '.rounding')
             else:
                 op._op._rename_input(inputs.name, out.name)

-    def _isolate_blocks(self):
-        starts = [block[0] for block in self._blocks]
+    def _isolate_regions(self):
+        starts = [region[0] for region in self._regions]
         var2duplications = self._duplicate_vars(starts)
         for vars_ in var2duplications.values():
             for var_ in vars_:
```
```diff
@@ -495,49 +675,301 @@ class RoundingOptimizer(object):
                 for op in var.outputs():
                     var_ = var._var
                     op_ = op._op
-                    duplicated_var = block.create_var(name=var_.name + ".assign" + str(index),
-                                                      type=var_.type,
-                                                      shape=var_.shape,
-                                                      dtype=var_.dtype)
+                    duplicated_var = block.create_var(
+                        name=var_.name + ".assign" + str(index),
+                        type=var_.type,
+                        shape=var_.shape,
+                        dtype=var_.dtype, )
                     vars.append(duplicated_var)
                     index += 1
                     idx = block.ops.index(op_)
-                    block._insert_op(idx,
-                                     type="assign",
-                                     inputs={"X": var_},
-                                     outputs={"Out": duplicated_var})
+                    block._insert_op(
+                        idx,
+                        type="assign",
+                        inputs={"X": var_},
+                        outputs={"Out": duplicated_var}, )
                     op_._rename_input(var_.name, duplicated_var.name)
         return vars

     def _update_weights_to_int(self):
         for weight_var_name in self._weight_var_names:
-            alpha_tensor = utils.load_variable_data(self._scope, weight_var_name + '.alpha')
+            alpha_tensor = utils.load_variable_data(
+                self._scope,
+                weight_var_name + '.alpha', )
             h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
-            weight_quant_tensor = utils.load_variable_data(self._scope, weight_var_name)
-            utils.set_variable_data(self._scope, self._place, weight_var_name,
-                                    np.round(weight_quant_tensor + h_alpha_tensor))
+            weight_quant_tensor = utils.load_variable_data(
+                self._scope,
+                weight_var_name, )
+            utils.set_variable_data(
+                self._scope,
+                self._place,
+                weight_var_name,
+                np.round(weight_quant_tensor + h_alpha_tensor, ),
+            )

     def _bias_correction_w(self):
         for weight_var_name in self._weight_var_names:
-            weight_var_tensor = utils.load_variable_data(self._scope, "teacher_" + weight_var_name)
-            weight_quant_tensor = utils.load_variable_data(self._scope, weight_var_name)
+            weight_var_tensor = utils.load_variable_data(
+                self._scope,
+                "teacher_" + weight_var_name, )
+            weight_quant_tensor = utils.load_variable_data(
+                self._scope,
+                weight_var_name, )
             scale = self._scale_dict[weight_var_name]
             final_weight_tensor = utils.bias_correction_w(
                 weight_var_tensor,
                 weight_quant_tensor,
                 scale,
                 quant_axis=0,
-                weight_bits=8)
-            utils.set_variable_data(self._scope, self._place, weight_var_name, final_weight_tensor)
+                weight_bits=8, )
+            utils.set_variable_data(
+                self._scope,
+                self._place,
+                weight_var_name,
+                final_weight_tensor, )

     def _compute_soft_rounding_np(self, alpha_v):
-        return np.clip(utils.stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
-                       a_min=0,
-                       a_max=1)
+        return np.clip(
+            utils.stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
+            a_min=0,
+            a_max=1, )

     def _all_persistable_var_names(self):
         persistable_var_names = []
         for var in self._program.list_vars():
             if var.persistable:
                 persistable_var_names.append(var.name)
         return persistable_var_names
```
The diff then adds the renamed loss class near the end of the module:

```python
class ReconstructionQuanterLoss(object):
    def __init__(self,
                 program,
                 weight_region_names=None,
                 round_loss_type='relaxation',
                 rec_loss_type='mse',
                 beta_type='const',
                 weight=0.1):
        """
        The loss function of Rounding Optimizer.
        Args:
            program(Program): The student program.
            weight_region_names(list, optional): The weight names inside a region.
            round_loss_type(str): The type of rounding loss function.
            rec_loss_type(str): The type of reconstruction loss function.
            beta_type(str): The type of hyper-parameter beta.
        Returns:
            total_loss(Variable): The sum of rounding loss and reconstruction loss.
            rec_loss(Variable): The reconstruction loss.
            round_loss(Variable): The rounding loss.
        """
        self.program = program
        self.round_loss_type = round_loss_type
        self.weight = weight
        self.rec_loss_type = rec_loss_type
        self.weight_region_names = weight_region_names
        self.beta_type = beta_type

    def compute_soft_rounding(self, alpha_v):
        return paddle.clip(
            paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0,
            1)

    def get_loss(self, student_tensor, teacher_tensor, scheduler):
        if self.rec_loss_type == 'mse':
            rec_loss = paddle.nn.functional.mse_loss(
                student_tensor,
                teacher_tensor, )
        else:
            raise ValueError(
                'Not supported reconstruction loss function: {}'.format(
                    self.rec_loss, ), )

        if self.beta_type == 'const':
            self.beta = 3
        else:
            self.beta = scheduler.get_lr()

        if self.round_loss_type == 'relaxation':
            round_loss = 0.0
            for name in self.weight_region_names:
                alpha_v = self.program.global_block().var(name + '.alpha')
                h_v = self.compute_soft_rounding(alpha_v)
                round_loss += self.weight * \
                    paddle.sum(-paddle.pow(paddle.abs(2 * h_v - 1), self.beta) + 1)
        else:
            raise NotImplementedError
        total_loss = rec_loss + round_loss
        return total_loss, rec_loss, round_loss
```
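As a worked example (mine, not part of the commit), the rectified sigmoid h(alpha) = clip(sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0, 1) and the relaxation loss sum(1 - |2h - 1|^beta) used above can be checked in NumPy; the loss vanishes once every h saturates at 0 or 1, i.e. once every weight has committed to rounding down or up:

```python
import numpy as np

GAMMA, ZETA = -0.1, 1.1

def soft_round(alpha):
    return np.clip(1 / (1 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0, 1)

def round_loss(alpha, weight=0.1, beta=3):
    h = soft_round(alpha)
    # Each element contributes 1 - |2h - 1|**beta: 0 when h is exactly 0 or 1,
    # 1 when h sits at 0.5, so minimizing pushes h to a hard 0/1 decision.
    return weight * np.sum(-np.abs(2 * h - 1)**beta + 1)

alpha = np.array([-4.0, 0.0, 4.0])
print(soft_round(alpha))   # -> [0.  0.5 1. ]
print(round_loss(alpha))   # -> 0.1 (only the middle entry contributes)
```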
def
quant_recon_static
(
executor
,
model_dir
,
quantize_model_path
,
batch_generator
=
None
,
sample_generator
=
None
,
data_loader
=
None
,
model_filename
=
None
,
params_filename
=
None
,
save_model_filename
=
'model.pdmodel'
,
save_params_filename
=
'model.pdiparams'
,
batch_size
=
1
,
batch_nums
=
None
,
scope
=
None
,
algo
=
'hist'
,
recon_level
=
'layer-wise'
,
simulate_activation_quant
=
False
,
hist_percent
=
0.9999
,
bias_correction
=
False
,
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"matmul"
,
"matmul_v2"
,
],
is_full_quantize
=
False
,
weight_bits
=
8
,
activation_bits
=
8
,
activation_quantize_type
=
'range_abs_max'
,
weight_quantize_type
=
'channel_wise_abs_max'
,
optimize_model
=
False
,
onnx_format
=
False
,
skip_tensor_list
=
None
,
is_use_cache_file
=
False
,
cache_dir
=
"./temp_recon_quantization"
,
regions
=
None
,
region_weights_names
=
None
,
epochs
=
20
,
scale_trainable
=
False
,
drop_prob
=
0.5
,
lr
=
0.1
):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
scale factor of quantized variables, and inserts fake quantization
and dequantization operators to obtain the quantized model.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
model_dir(str): The path of fp32 model that will be quantized, and
the model and params that saved by ``paddle.static.io.save_inference_model``
are under the path.
quantize_model_path(str): The path to save quantized model using api
``paddle.static.io.save_inference_model``.
batch_generator(Python Generator): The batch generator provides
calibrate data for DataLoader, and it returns a batch every
time. For sample_generator and batch_generator, only one
can be set. Beisdes, batch_generator supports lod tensor.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every time.
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
return a batch every time.
model_filename(str, optional): The name of model file. If parameters
are saved in separate files, set it as 'None'. Default: 'None'.
params_filename(str, optional): The name of params file.
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: 'model.pdmodel'.
save_params_filename(str): The name of file to save all related parameters.
If it is set None, parameters will be saved in separate files. Default: 'model.pdiparams'.
batch_size(int, optional): The batch size of DataLoader, default is 1.
batch_nums(int, optional): If batch_nums is not None, the number of calibrate
data is 'batch_size*batch_nums'. If batch_nums is None, use all data
generated by sample_generator as calibrate data.
scope(paddle.static.Scope, optional): The scope to run program, use it to load
and save variables. If scope is None, will use paddle.static.global_scope().
algo(str, optional): If algo='KL', use KL-divergenc method to
get the scale factor. If algo='hist', use the hist_percent of histogram
to get the scale factor. If algo='mse', search for the best scale factor which
makes the mse loss minimal. Use one batch of data for mse is enough. If
algo='avg', use the average of abs_max values to get the scale factor. If
algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'.
recon_level(str, optional): The type of reconstruction granularity.
Currently support ['layer-wise', 'region-wise'] types. Default is layer-wise.
simulate_activation_quant(bool, optional): Whether we need the noise caused by activation
quantization during the reconstruction process. Default is False.
hist_percent(float, optional): The percentile of histogram for algo hist.Default:0.9999.
bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723.
Default: False.
quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default: ["conv2d", "depthwise_conv2d", "mul"].
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
activation_quantize_type(str): quantization type for activation,
now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
This parameter only specifies the fake ops in quantized model.
If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
obtained by post training quantization in fake ops. If it
is 'abs_max', the scale will not be saved in fake ops.
weight_quantize_type(str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. Compared to 'abs_max',
the model accuracy is usually higher when using 'channel_wise_abs_max'.
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type.
If False, only apply quantization to the input quantizable_op_type. Default is False.
optimize_model(bool, optional): If set optimize_model as True, it applies some
passes to optimize the model before quantization. So far, the place of
executor must be cpu it supports fusing batch_norm into convs.
onnx_format(bool): Whether to export the quantized model with format of ONNX. Default is False.
skip_tensor_list(list): List of skip quant tensor name.
is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated.
epochs: The number of steps in the reconstruction proces. Default is 20.
scale_trainable: Wether weight‘s scale is trainable. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
regions(list[list], optional): The list of some regions, each region is a subgraph of
fp32 program and it will have exact 1 input operation and 1 output operation. When
the recon-level is region, the reconstruction loss of each region is minimized.
Default is None.
region_weights_names(list[list], optional): The weight names inside every region.
Default is None.
Returns:
None
"""
    PTQCollections = Collections(
        executor=executor,
        sample_generator=sample_generator,
        batch_generator=batch_generator,
        data_loader=data_loader,
        model_dir=model_dir,
        model_filename=model_filename,
        params_filename=params_filename,
        batch_size=batch_size,
        batch_nums=batch_nums,
        scope=scope,
        algo=algo,
        hist_percent=hist_percent,
        bias_correction=bias_correction,
        quantizable_op_type=quantizable_op_type,
        is_full_quantize=is_full_quantize,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        activation_quantize_type=activation_quantize_type,
        weight_quantize_type=weight_quantize_type,
        onnx_format=onnx_format,
        skip_tensor_list=skip_tensor_list,
        optimize_model=optimize_model,
        round_type='adaround')

    RSQCollections = Collections(
        recon_level=recon_level,
        simulate_activation_quant=simulate_activation_quant,
        regions=regions,
        region_weights_names=region_weights_names,
        epochs=epochs,
        scale_trainable=scale_trainable,
        lr=lr)

    reconstruction_quantization = ReconstructionQuantization(
        PTQCollections=PTQCollections, RSQCollections=RSQCollections)

    reconstruction_quantization.quantize()
    reconstruction_quantization.save_quantized_model(
        quantize_model_path,
        model_filename=save_model_filename,
        params_filename=save_params_filename)
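For orientation, the call below is a minimal usage sketch of quant_recon_static rather than code from this commit: the model directory, file names, input shape, and calibration reader are placeholder assumptions, and the region arguments are omitted because recon_level='layer-wise' does not require them.

    import numpy as np
    import paddle
    from paddleslim.quant import quant_recon_static

    paddle.enable_static()
    place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    def calib_reader():
        # Placeholder calibration data; substitute a real sample reader.
        for _ in range(32):
            yield np.random.random([1, 28, 28]).astype('float32'),

    quant_recon_static(
        exe,
        './fp32_model',                       # placeholder: directory of the saved fp32 model
        quantize_model_path='./quant_model',  # placeholder output directory
        sample_generator=calib_reader,
        model_filename='model',
        params_filename='params',
        batch_nums=10,
        algo='abs_max',
        recon_level='layer-wise',
        simulate_activation_quant=True,
        epochs=20)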
tests/test_rounding_opimizer.py → tests/test_reconstruct_quantization.py
@@ -22,15 +22,15 @@ from models import MobileNet
 from layers import conv_bn_layer
 import paddle.dataset.mnist as reader
 import numpy as np
-from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
-from paddleslim.quant.rounding_optimizer import RoundingOptimizer
+from paddleslim.quant import quant_recon_static


 class TestRoundingOptimizer(StaticCase):
     def __init__(self, *args, **kwargs):
         super(TestRoundingOptimizer, self).__init__(*args, **kwargs)
         paddle.enable_static()
         self._gen_model()

     def _gen_model(self):
         image = paddle.static.data(
             name='image', shape=[None, 1, 28, 28], dtype='float32')
@@ -52,13 +52,15 @@ class TestRoundingOptimizer(StaticCase):
         ) else paddle.CPUPlace()
         exe = paddle.static.Executor(place)
         exe.run(paddle.static.default_startup_program())

         def transform(x):
             return np.reshape(x, [1, 28, 28])

         train_dataset = paddle.vision.datasets.MNIST(
             mode='train', backend='cv2', transform=transform)
         test_dataset = paddle.vision.datasets.MNIST(
             mode='test', backend='cv2', transform=transform)
-        train_loader = paddle.io.DataLoader(
+        self.train_loader = paddle.io.DataLoader(
             train_dataset,
             places=place,
             feed_list=[image, label],
@@ -71,15 +73,18 @@ class TestRoundingOptimizer(StaticCase):
             feed_list=[image, label],
             batch_size=64,
             return_list=False)

         def sample_generator_creator():
             def __reader__():
                 for data in test_dataset:
                     image, label = data
                     yield image, label

             return __reader__

         def train(program):
             iter = 0
-            for data in train_loader():
+            for data in self.train_loader():
                 cost, top1, top5 = exe.run(
                     program,
                     feed=data,
@@ -89,6 +94,7 @@ class TestRoundingOptimizer(StaticCase):
                     print(
                         'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                         format(iter, cost, top1, top5))

         train(main_prog)
         paddle.fluid.io.save_inference_model(
             dirname='./test_rounding_optimizer',
@@ -98,55 +104,59 @@ class TestRoundingOptimizer(StaticCase):
             executor=exe,
             model_filename='model',
             params_filename='params')

-        self.post_training_quantization = PostTrainingQuantization(
-            exe,
-            './test_rounding_optimizer',
-            sample_generator=sample_generator_creator(),
-            model_filename='model',
-            params_filename='params',
-            batch_nums=10,
-            algo='abs_max',
-            bias_correction=True)
-        self.post_training_quantization._load_model_data()
-        self.post_training_quantization._collect_target_varnames()
-        self.post_training_quantization._set_activation_persistable()
-        for data in self.post_training_quantization._data_loader():
-            self.post_training_quantization._executor.run(
-                program=self.post_training_quantization._program,
-                feed=data,
-                fetch_list=self.post_training_quantization._fetch_list,
-                return_numpy=False,
-                scope=self.post_training_quantization._scope)
-        self.post_training_quantization._sampling()
-        self.post_training_quantization._reset_activation_persistable()
-        self._blocks = [['image', 'batch_norm_26.tmp_4']]
-        self._block_weights_names = [[
+        self.data_loader = sample_generator_creator()
+        self._regions = [['image', 'batch_norm_26.tmp_4']]
+        self._region_weights_names = [[
             'conv1_weights', 'conv2_1_dw_weights', 'conv2_1_sep_weights',
             'conv2_2_dw_weights', 'conv2_2_sep_weights', 'conv3_1_dw_weights',
             'conv3_1_sep_weights', 'conv3_2_dw_weights', 'conv3_2_sep_weights',
             'conv4_1_dw_weights', 'conv4_1_sep_weights', 'conv4_2_dw_weights',
             'conv4_2_sep_weights', 'conv5_1_dw_weights', 'conv5_1_sep_weights',
             'conv5_2_dw_weights', 'conv5_2_sep_weights', 'conv5_3_dw_weights',
             'conv5_3_sep_weights', 'conv5_4_dw_weights', 'conv5_4_sep_weights',
             'conv5_5_dw_weights', 'conv5_5_sep_weights', 'conv5_6_dw_weights',
             'conv5_6_sep_weights', 'conv6_dw_weights', 'conv6_sep_weights'
         ]]

     def test_qdrop(self):
-        rounding_optimizer = RoundingOptimizer(
-            data_loader=self.post_training_quantization._data_loader,
-            fp32_program=self.post_training_quantization._program,
-            feed_list=self.post_training_quantization._feed_list,
-            fetch_list=self.post_training_quantization._fetch_list,
-            exe=self.post_training_quantization._executor,
-            scope=self.post_training_quantization._scope,
-            place=self.post_training_quantization._place,
-            quantized_op_pairs=self.post_training_quantization._quantized_op_pairs,
-            weight_quantize_type=self.post_training_quantization._weight_quantize_type,
-            scale_dict=self.post_training_quantization._quantized_threshold,
-            blocks=self._blocks,
-            block_weights_names=self._block_weights_names,
-            round_type='qdrop',
-            num_iterations=self.post_training_quantization._batch_nums,
-            lr=self.post_training_quantization._learning_rate,
-            bias_correction=self.post_training_quantization._bias_correction,
-            epochs=10,
-        )
-        rounding_optimizer._run()
-        rounding_optimizer._get_layers()
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        quant_recon_static(
+            exe,
+            './test_rounding_optimizer',
+            quantize_model_path='rsq_out',
+            sample_generator=self.data_loader,
+            model_filename='model',
+            params_filename='params',
+            batch_nums=10,
+            algo='abs_max',
+            regions=self._regions,
+            region_weights_names=self._region_weights_names,
+            recon_level='region-wise',
+            simulate_activation_quant=True)
+
+    def test_qdrop(self):
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        quant_recon_static(
+            exe,
+            './test_rounding_optimizer',
+            quantize_model_path='rsq_out',
+            sample_generator=self.data_loader,
+            model_filename='model',
+            params_filename='params',
+            batch_nums=10,
+            algo='KL',
+            regions=self._regions,
+            region_weights_names=self._region_weights_names,
+            recon_level='layer-wise',
+            simulate_activation_quant=True,
+            bias_correction=True)


 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
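As background for the simulate_activation_quant=True argument exercised in both new tests: during reconstruction, activation-quantization noise is injected in the QDrop style, where each activation element keeps its floating-point value with probability drop_prob and is fake-quantized otherwise. The numpy snippet below is only a toy illustration of that idea under an assumed per-tensor abs_max scale and 8-bit quantization; it is not code from this PR.

    import numpy as np

    def simulate_act_quant(x, drop_prob=0.5, bits=8):
        # Fake-quantize with a per-tensor abs_max scale.
        bnt = (1 << (bits - 1)) - 1                    # 127 for 8 bits
        scale = np.abs(x).max() / bnt
        x_q = np.clip(np.round(x / scale), -bnt, bnt) * scale
        # QDrop-style noise: keep the float value with probability drop_prob.
        keep = np.random.rand(*x.shape) < drop_prob
        return np.where(keep, x, x_q)

    x = np.random.randn(4, 8).astype('float32')
    print(simulate_act_quant(x))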