Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
7781902e
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7781902e
编写于
5月 19, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update post_quantization.py
上级
13975d0f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
126 addition
and
71 deletion
+126
-71
contrib/HumanSeg/utils/post_quantization.py
contrib/HumanSeg/utils/post_quantization.py
+126
-71
未找到文件。
contrib/HumanSeg/utils/post_quantization.py
浏览文件 @
7781902e
...
...
@@ -14,12 +14,14 @@
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
QuantizationTransformPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
AddQuantDequantPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
_o
p_real_in_out_name
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
_o
ut_scale_op_list
from
paddle.fluid.contrib.slim.quantization
import
PostTrainingQuantization
import
utils.logging
as
logging
import
paddle.fluid
as
fluid
import
os
import
utils.logging
as
logging
import
re
import
numpy
as
np
import
time
class
HumanSegPostTrainingQuantization
(
PostTrainingQuantization
):
...
...
@@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
fp32 model. It uses calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
...
...
@@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
Returns:
None
'''
self
.
_support_activation_quantize_type
=
[
'range_abs_max'
,
'moving_average_abs_max'
,
'abs_max'
]
self
.
_support_weight_quantize_type
=
[
'abs_max'
,
'channel_wise_abs_max'
]
self
.
_support_algo_type
=
[
'KL'
,
'abs_max'
,
'min_max'
]
self
.
_support_quantize_op_type
=
\
list
(
set
(
QuantizationTransformPass
.
_supported_quantizable_op_type
+
AddQuantDequantPass
.
_supported_quantizable_op_type
))
# Check inputs
assert
executor
is
not
None
,
"The executor cannot be None."
assert
batch_size
>
0
,
"The batch_size should be greater than 0."
assert
algo
in
self
.
_support_algo_type
,
\
"The algo should be KL, abs_max or min_max."
self
.
_executor
=
executor
self
.
_dataset
=
dataset
self
.
_batch_size
=
batch_size
...
...
@@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self
.
_algo
=
algo
self
.
_is_use_cache_file
=
is_use_cache_file
self
.
_cache_dir
=
cache_dir
self
.
_activation_bits
=
8
self
.
_weight_bits
=
8
self
.
_activation_quantize_type
=
'range_abs_max'
self
.
_weight_quantize_type
=
'channel_wise_abs_max'
if
self
.
_is_use_cache_file
and
not
os
.
path
.
exists
(
self
.
_cache_dir
):
os
.
mkdir
(
self
.
_cache_dir
)
supported_quantizable_op_type
=
\
QuantizationTransformPass
.
_supported_quantizable_op_type
+
\
AddQuantDequantPass
.
_supported_quantizable_op_type
if
is_full_quantize
:
self
.
_quantizable_op_type
=
s
upported_quantizabl
e_op_type
self
.
_quantizable_op_type
=
s
elf
.
_support_quantiz
e_op_type
else
:
self
.
_quantizable_op_type
=
quantizable_op_type
for
op_type
in
self
.
_quantizable_op_type
:
assert
op_type
in
s
upported_quantizabl
e_op_type
+
\
assert
op_type
in
s
elf
.
_support_quantiz
e_op_type
+
\
AddQuantDequantPass
.
_activation_type
,
\
op_type
+
" is not supported for quantization."
...
...
@@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self
.
_fetch_list
=
list
(
outputs
.
values
())
self
.
_data_loader
=
None
self
.
_o
p_real_in_out_name
=
_op_real_in_out_name
self
.
_o
ut_scale_op_list
=
_out_scale_op_list
self
.
_bit_length
=
8
self
.
_quantized_weight_var_name
=
set
()
self
.
_quantized_act_var_name
=
set
()
self
.
_sampling_data
=
{}
self
.
_quantized_var_scale_factor
=
{}
self
.
_quantized_var_kl_threshold
=
{}
self
.
_quantized_var_min
=
{}
self
.
_quantized_var_max
=
{}
self
.
_quantized_var_abs_max
=
{}
def
quantize
(
self
):
'''
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
None
Returns:
the program of quantized model.
'''
self
.
_preprocess
()
self
.
_load_model_data
()
self
.
_collect_target_varnames
()
self
.
_set_activation_persistable
()
batch_ct
=
0
for
data
in
self
.
_data_loader
():
batch_ct
+=
1
if
self
.
_batch_nums
and
batch_ct
>=
self
.
_batch_nums
:
break
batch_id
=
0
logging
.
info
(
"Start to run batch!"
)
for
data
in
self
.
_data_loader
():
start
=
time
.
time
()
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
)
self
.
_sample_data
(
batch_id
)
if
batch_id
%
5
==
0
:
logging
.
info
(
"run batch: {}"
.
format
(
batch_id
))
if
self
.
_algo
==
"KL"
:
self
.
_sample_data
(
batch_id
)
else
:
self
.
_sample_threshold
()
end
=
time
.
time
()
logging
.
debug
(
'[Run batch data] Batch={}/{}, time_each_batch={} s.'
.
format
(
str
(
batch_id
+
1
),
str
(
batch_ct
),
str
(
end
-
start
)))
batch_id
+=
1
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
logging
.
info
(
"all run batch: "
.
format
(
batch_id
))
logging
.
info
(
"calculate scale factor ..."
)
self
.
_calculate_scale_factor
()
logging
.
info
(
"update the program ..."
)
self
.
_update_program
()
self
.
_save_output_scale
()
logging
.
info
(
"All run batch: "
.
format
(
batch_id
))
self
.
_reset_activation_persistable
()
logging
.
info
(
"Calculate scale factor ..."
)
if
self
.
_algo
==
"KL"
:
self
.
_calculate_kl_threshold
()
logging
.
info
(
"Update the program ..."
)
if
self
.
_algo
in
[
"KL"
,
"abs_max"
]:
self
.
_update_program
()
else
:
self
.
_save_input_threhold
()
logging
.
info
(
"Save ..."
)
self
.
_save_output_threshold
()
logging
.
info
(
"Finish quant!"
)
return
self
.
_program
def
save_quantized_model
(
self
,
save_model_path
):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Returns:
...
...
@@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
params_filename
=
'__params__'
,
main_program
=
self
.
_program
)
def
_
preprocess
(
self
):
def
_
load_model_data
(
self
):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
Set data loader.
'''
feed_vars
=
[
fluid
.
framework
.
_get_var
(
var
.
name
,
self
.
_program
)
\
for
var
in
self
.
_feed_list
]
self
.
_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
feed_vars
,
capacity
=
3
*
self
.
_batch_size
,
iterable
=
True
)
self
.
_data_loader
.
set_sample_list_generator
(
self
.
_dataset
.
generator
(
self
.
_batch_size
,
drop_last
=
True
),
places
=
self
.
_place
)
# collect the variable names for sampling
persistable_var_names
=
[]
for
var
in
self
.
_program
.
list_vars
():
if
var
.
persistable
:
persistable_var_names
.
append
(
var
.
name
)
for
op
in
self
.
_program
.
global_block
().
ops
:
op_type
=
op
.
type
if
op_type
in
self
.
_quantizable_op_type
:
if
op_type
in
(
"conv2d"
,
"depthwise_conv2d"
):
self
.
_quantized_act_var_name
.
add
(
op
.
input
(
"Input"
)[
0
])
self
.
_quantized_weight_var_name
.
add
(
op
.
input
(
"Filter"
)[
0
])
self
.
_quantized_act_var_name
.
add
(
op
.
output
(
"Output"
)[
0
])
elif
op_type
==
"mul"
:
if
self
.
_is_input_all_not_persistable
(
op
,
persistable_var_names
):
op
.
_set_attr
(
"skip_quant"
,
True
)
logging
.
warning
(
"Skip quant a mul op for two input variables are not persistable"
)
else
:
self
.
_quantized_act_var_name
.
add
(
op
.
input
(
"X"
)[
0
])
self
.
_quantized_weight_var_name
.
add
(
op
.
input
(
"Y"
)[
0
])
self
.
_quantized_act_var_name
.
add
(
op
.
output
(
"Out"
)[
0
])
else
:
# process other quantizable op type, the input must all not persistable
if
self
.
_is_input_all_not_persistable
(
op
,
persistable_var_names
):
input_output_name_list
=
self
.
_op_real_in_out_name
[
op_type
]
for
input_name
in
input_output_name_list
[
0
]:
for
var_name
in
op
.
input
(
input_name
):
self
.
_quantized_act_var_name
.
add
(
var_name
)
for
output_name
in
input_output_name_list
[
1
]:
for
var_name
in
op
.
output
(
output_name
):
self
.
_quantized_act_var_name
.
add
(
var_name
)
def
_calculate_kl_threshold
(
self
):
'''
Calculate the KL threshold of quantized variables.
'''
assert
self
.
_algo
==
"KL"
,
"The algo should be KL to calculate kl threshold."
ct
=
1
# Abs_max threshold for weights
for
var_name
in
self
.
_quantized_weight_var_name
:
start
=
time
.
time
()
weight_data
=
self
.
_sampling_data
[
var_name
]
weight_threshold
=
None
if
self
.
_weight_quantize_type
==
"abs_max"
:
weight_threshold
=
np
.
max
(
np
.
abs
(
weight_data
))
elif
self
.
_weight_quantize_type
==
"channel_wise_abs_max"
:
weight_threshold
=
[]
for
i
in
range
(
weight_data
.
shape
[
0
]):
abs_max_value
=
np
.
max
(
np
.
abs
(
weight_data
[
i
]))
weight_threshold
.
append
(
abs_max_value
)
self
.
_quantized_var_kl_threshold
[
var_name
]
=
weight_threshold
end
=
time
.
time
()
logging
.
debug
(
'[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_weight_var_name
)),
str
(
end
-
start
)))
ct
+=
1
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for
var
in
self
.
_program
.
list_vars
():
if
var
.
name
in
self
.
_quantized_act_var_name
:
var
.
persistable
=
True
ct
=
1
# KL threshold for activations
if
self
.
_is_use_cache_file
:
for
var_name
in
self
.
_quantized_act_var_name
:
start
=
time
.
time
()
sampling_data
=
[]
filenames
=
[
f
for
f
in
os
.
listdir
(
self
.
_cache_dir
)
\
if
re
.
match
(
var_name
+
'_[0-9]+.npy'
,
f
)]
for
filename
in
filenames
:
file_path
=
os
.
path
.
join
(
self
.
_cache_dir
,
filename
)
sampling_data
.
append
(
np
.
load
(
file_path
))
os
.
remove
(
file_path
)
sampling_data
=
np
.
concatenate
(
sampling_data
)
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
self
.
_get_kl_scaling_factor
(
np
.
abs
(
sampling_data
))
end
=
time
.
time
()
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
ct
+=
1
else
:
for
var_name
in
self
.
_quantized_act_var_name
:
start
=
time
.
time
()
self
.
_sampling_data
[
var_name
]
=
np
.
concatenate
(
self
.
_sampling_data
[
var_name
])
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
self
.
_get_kl_scaling_factor
(
np
.
abs
(
self
.
_sampling_data
[
var_name
]))
end
=
time
.
time
()
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
ct
+=
1
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录