Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
cbce658d
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cbce658d
编写于
5月 19, 2020
作者:
W
wuyefeilin
提交者:
GitHub
5月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update post_quantization.py (#255)
* update train.py * update post_quantization.py
上级
27121d0f
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
126 addition
and
71 deletion
+126
-71
contrib/HumanSeg/utils/post_quantization.py
contrib/HumanSeg/utils/post_quantization.py
+126
-71
未找到文件。
contrib/HumanSeg/utils/post_quantization.py
浏览文件 @
cbce658d
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
QuantizationTransformPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
QuantizationTransformPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
AddQuantDequantPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
AddQuantDequantPass
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
_o
p_real_in_out_name
from
paddle.fluid.contrib.slim.quantization.quantization_pass
import
_o
ut_scale_op_list
from
paddle.fluid.contrib.slim.quantization
import
PostTrainingQuantization
from
paddle.fluid.contrib.slim.quantization
import
PostTrainingQuantization
import
utils.logging
as
logging
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
os
import
os
import
re
import
utils.logging
as
logging
import
numpy
as
np
import
time
class
HumanSegPostTrainingQuantization
(
PostTrainingQuantization
):
class
HumanSegPostTrainingQuantization
(
PostTrainingQuantization
):
...
@@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
...
@@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
fp32 model. It uses calibrate data to calculate the scale factor of
fp32 model. It uses calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
quantized model.
Args:
Args:
executor(fluid.Executor): The executor to load, run and save the
executor(fluid.Executor): The executor to load, run and save the
quantized model.
quantized model.
...
@@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
...
@@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
Returns:
Returns:
None
None
'''
'''
self
.
_support_activation_quantize_type
=
[
'range_abs_max'
,
'moving_average_abs_max'
,
'abs_max'
]
self
.
_support_weight_quantize_type
=
[
'abs_max'
,
'channel_wise_abs_max'
]
self
.
_support_algo_type
=
[
'KL'
,
'abs_max'
,
'min_max'
]
self
.
_support_quantize_op_type
=
\
list
(
set
(
QuantizationTransformPass
.
_supported_quantizable_op_type
+
AddQuantDequantPass
.
_supported_quantizable_op_type
))
# Check inputs
assert
executor
is
not
None
,
"The executor cannot be None."
assert
batch_size
>
0
,
"The batch_size should be greater than 0."
assert
algo
in
self
.
_support_algo_type
,
\
"The algo should be KL, abs_max or min_max."
self
.
_executor
=
executor
self
.
_executor
=
executor
self
.
_dataset
=
dataset
self
.
_dataset
=
dataset
self
.
_batch_size
=
batch_size
self
.
_batch_size
=
batch_size
...
@@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
...
@@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self
.
_algo
=
algo
self
.
_algo
=
algo
self
.
_is_use_cache_file
=
is_use_cache_file
self
.
_is_use_cache_file
=
is_use_cache_file
self
.
_cache_dir
=
cache_dir
self
.
_cache_dir
=
cache_dir
self
.
_activation_bits
=
8
self
.
_weight_bits
=
8
self
.
_activation_quantize_type
=
'range_abs_max'
self
.
_weight_quantize_type
=
'channel_wise_abs_max'
if
self
.
_is_use_cache_file
and
not
os
.
path
.
exists
(
self
.
_cache_dir
):
if
self
.
_is_use_cache_file
and
not
os
.
path
.
exists
(
self
.
_cache_dir
):
os
.
mkdir
(
self
.
_cache_dir
)
os
.
mkdir
(
self
.
_cache_dir
)
supported_quantizable_op_type
=
\
QuantizationTransformPass
.
_supported_quantizable_op_type
+
\
AddQuantDequantPass
.
_supported_quantizable_op_type
if
is_full_quantize
:
if
is_full_quantize
:
self
.
_quantizable_op_type
=
s
upported_quantizabl
e_op_type
self
.
_quantizable_op_type
=
s
elf
.
_support_quantiz
e_op_type
else
:
else
:
self
.
_quantizable_op_type
=
quantizable_op_type
self
.
_quantizable_op_type
=
quantizable_op_type
for
op_type
in
self
.
_quantizable_op_type
:
for
op_type
in
self
.
_quantizable_op_type
:
assert
op_type
in
s
upported_quantizabl
e_op_type
+
\
assert
op_type
in
s
elf
.
_support_quantiz
e_op_type
+
\
AddQuantDequantPass
.
_activation_type
,
\
AddQuantDequantPass
.
_activation_type
,
\
op_type
+
" is not supported for quantization."
op_type
+
" is not supported for quantization."
...
@@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
...
@@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self
.
_fetch_list
=
list
(
outputs
.
values
())
self
.
_fetch_list
=
list
(
outputs
.
values
())
self
.
_data_loader
=
None
self
.
_data_loader
=
None
self
.
_o
p_real_in_out_name
=
_op_real_in_out_name
self
.
_o
ut_scale_op_list
=
_out_scale_op_list
self
.
_bit_length
=
8
self
.
_bit_length
=
8
self
.
_quantized_weight_var_name
=
set
()
self
.
_quantized_weight_var_name
=
set
()
self
.
_quantized_act_var_name
=
set
()
self
.
_quantized_act_var_name
=
set
()
self
.
_sampling_data
=
{}
self
.
_sampling_data
=
{}
self
.
_quantized_var_scale_factor
=
{}
self
.
_quantized_var_kl_threshold
=
{}
self
.
_quantized_var_min
=
{}
self
.
_quantized_var_max
=
{}
self
.
_quantized_var_abs_max
=
{}
def
quantize
(
self
):
def
quantize
(
self
):
'''
'''
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
quantized model.
Args:
Args:
None
None
Returns:
Returns:
the program of quantized model.
the program of quantized model.
'''
'''
self
.
_preprocess
()
self
.
_load_model_data
()
self
.
_collect_target_varnames
()
self
.
_set_activation_persistable
()
batch_ct
=
0
for
data
in
self
.
_data_loader
():
batch_ct
+=
1
if
self
.
_batch_nums
and
batch_ct
>=
self
.
_batch_nums
:
break
batch_id
=
0
batch_id
=
0
logging
.
info
(
"Start to run batch!"
)
for
data
in
self
.
_data_loader
():
for
data
in
self
.
_data_loader
():
start
=
time
.
time
()
self
.
_executor
.
run
(
self
.
_executor
.
run
(
program
=
self
.
_program
,
program
=
self
.
_program
,
feed
=
data
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
)
return_numpy
=
False
)
if
self
.
_algo
==
"KL"
:
self
.
_sample_data
(
batch_id
)
self
.
_sample_data
(
batch_id
)
else
:
if
batch_id
%
5
==
0
:
self
.
_sample_threshold
()
logging
.
info
(
"run batch: {}"
.
format
(
batch_id
))
end
=
time
.
time
()
logging
.
debug
(
'[Run batch data] Batch={}/{}, time_each_batch={} s.'
.
format
(
str
(
batch_id
+
1
),
str
(
batch_ct
),
str
(
end
-
start
)))
batch_id
+=
1
batch_id
+=
1
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
break
logging
.
info
(
"all run batch: "
.
format
(
batch_id
))
logging
.
info
(
"All run batch: "
.
format
(
batch_id
))
logging
.
info
(
"calculate scale factor ..."
)
self
.
_reset_activation_persistable
()
self
.
_calculate_scale_factor
()
logging
.
info
(
"Calculate scale factor ..."
)
logging
.
info
(
"update the program ..."
)
if
self
.
_algo
==
"KL"
:
self
.
_calculate_kl_threshold
()
logging
.
info
(
"Update the program ..."
)
if
self
.
_algo
in
[
"KL"
,
"abs_max"
]:
self
.
_update_program
()
self
.
_update_program
()
else
:
self
.
_save_output_scale
()
self
.
_save_input_threhold
()
logging
.
info
(
"Save ..."
)
self
.
_save_output_threshold
()
logging
.
info
(
"Finish quant!"
)
return
self
.
_program
return
self
.
_program
def
save_quantized_model
(
self
,
save_model_path
):
def
save_quantized_model
(
self
,
save_model_path
):
'''
'''
Save the quantized model to the disk.
Save the quantized model to the disk.
Args:
Args:
save_model_path(str): The path to save the quantized model
save_model_path(str): The path to save the quantized model
Returns:
Returns:
...
@@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
...
@@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
params_filename
=
'__params__'
,
params_filename
=
'__params__'
,
main_program
=
self
.
_program
)
main_program
=
self
.
_program
)
def
_
preprocess
(
self
):
def
_
load_model_data
(
self
):
'''
'''
Load model and set data loader, collect the variable names for sampling,
Set data loader.
and set activation variables to be persistable.
'''
'''
feed_vars
=
[
fluid
.
framework
.
_get_var
(
var
.
name
,
self
.
_program
)
\
feed_vars
=
[
fluid
.
framework
.
_get_var
(
var
.
name
,
self
.
_program
)
\
for
var
in
self
.
_feed_list
]
for
var
in
self
.
_feed_list
]
self
.
_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
self
.
_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
feed_vars
,
capacity
=
3
*
self
.
_batch_size
,
iterable
=
True
)
feed_list
=
feed_vars
,
capacity
=
3
*
self
.
_batch_size
,
iterable
=
True
)
self
.
_data_loader
.
set_sample_list_generator
(
self
.
_data_loader
.
set_sample_list_generator
(
self
.
_dataset
.
generator
(
self
.
_batch_size
,
drop_last
=
True
),
self
.
_dataset
.
generator
(
self
.
_batch_size
,
drop_last
=
True
),
places
=
self
.
_place
)
places
=
self
.
_place
)
# collect the variable names for sampling
def
_calculate_kl_threshold
(
self
):
persistable_var_names
=
[]
'''
for
var
in
self
.
_program
.
list_vars
():
Calculate the KL threshold of quantized variables.
if
var
.
persistable
:
'''
persistable_var_names
.
append
(
var
.
name
)
assert
self
.
_algo
==
"KL"
,
"The algo should be KL to calculate kl threshold."
ct
=
1
# Abs_max threshold for weights
for
var_name
in
self
.
_quantized_weight_var_name
:
start
=
time
.
time
()
weight_data
=
self
.
_sampling_data
[
var_name
]
weight_threshold
=
None
if
self
.
_weight_quantize_type
==
"abs_max"
:
weight_threshold
=
np
.
max
(
np
.
abs
(
weight_data
))
elif
self
.
_weight_quantize_type
==
"channel_wise_abs_max"
:
weight_threshold
=
[]
for
i
in
range
(
weight_data
.
shape
[
0
]):
abs_max_value
=
np
.
max
(
np
.
abs
(
weight_data
[
i
]))
weight_threshold
.
append
(
abs_max_value
)
self
.
_quantized_var_kl_threshold
[
var_name
]
=
weight_threshold
end
=
time
.
time
()
logging
.
debug
(
'[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_weight_var_name
)),
str
(
end
-
start
)))
ct
+=
1
for
op
in
self
.
_program
.
global_block
().
ops
:
ct
=
1
op_type
=
op
.
type
# KL threshold for activations
if
op_type
in
self
.
_quantizable_op_type
:
if
self
.
_is_use_cache_file
:
if
op_type
in
(
"conv2d"
,
"depthwise_conv2d"
):
for
var_name
in
self
.
_quantized_act_var_name
:
self
.
_quantized_act_var_name
.
add
(
op
.
input
(
"Input"
)[
0
])
start
=
time
.
time
()
self
.
_quantized_weight_var_name
.
add
(
op
.
input
(
"Filter"
)[
0
])
sampling_data
=
[]
self
.
_quantized_act_var_name
.
add
(
op
.
output
(
"Output"
)[
0
])
filenames
=
[
f
for
f
in
os
.
listdir
(
self
.
_cache_dir
)
\
elif
op_type
==
"mul"
:
if
re
.
match
(
var_name
+
'_[0-9]+.npy'
,
f
)]
if
self
.
_is_input_all_not_persistable
(
for
filename
in
filenames
:
op
,
persistable_var_names
):
file_path
=
os
.
path
.
join
(
self
.
_cache_dir
,
filename
)
op
.
_set_attr
(
"skip_quant"
,
True
)
sampling_data
.
append
(
np
.
load
(
file_path
))
logging
.
warning
(
os
.
remove
(
file_path
)
"Skip quant a mul op for two input variables are not persistable"
sampling_data
=
np
.
concatenate
(
sampling_data
)
)
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
self
.
_get_kl_scaling_factor
(
np
.
abs
(
sampling_data
))
end
=
time
.
time
()
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
ct
+=
1
else
:
else
:
self
.
_quantized_act_var_name
.
add
(
op
.
input
(
"X"
)[
0
])
for
var_name
in
self
.
_quantized_act_var_name
:
self
.
_quantized_weight_var_name
.
add
(
op
.
input
(
"Y"
)[
0
])
start
=
time
.
time
()
self
.
_quantized_act_var_name
.
add
(
op
.
output
(
"Out"
)[
0
])
self
.
_sampling_data
[
var_name
]
=
np
.
concatenate
(
else
:
self
.
_sampling_data
[
var_name
])
# process other quantizable op type, the input must all not persistable
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
if
self
.
_is_input_all_not_persistable
(
self
.
_get_kl_scaling_factor
(
np
.
abs
(
self
.
_sampling_data
[
var_name
]))
op
,
persistable_var_names
):
end
=
time
.
time
()
input_output_name_list
=
self
.
_op_real_in_out_name
[
logging
.
debug
(
op_type
]
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
for
input_name
in
input_output_name_list
[
0
]:
.
format
(
for
var_name
in
op
.
input
(
input_name
):
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
self
.
_quantized_act_var_name
.
add
(
var_name
)
str
(
end
-
start
)))
for
output_name
in
input_output_name_list
[
1
]:
ct
+=
1
for
var_name
in
op
.
output
(
output_name
):
self
.
_quantized_act_var_name
.
add
(
var_name
)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for
var
in
self
.
_program
.
list_vars
():
if
var
.
name
in
self
.
_quantized_act_var_name
:
var
.
persistable
=
True
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录