Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
abb0b2d6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
abb0b2d6
编写于
6月 16, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
6月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick]Add progress bar and speed up Quantization Pass (#43454)
* Add progress bar and speed up Quantization Pass * fix typo
上级
7e940b84
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
237 addition
and
172 deletion
+237
-172
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
...d/contrib/slim/quantization/post_training_quantization.py
+31
-24
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+182
-146
python/paddle/fluid/contrib/slim/quantization/utils.py
python/paddle/fluid/contrib/slim/quantization/utils.py
+24
-2
未找到文件。
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
浏览文件 @
abb0b2d6
...
...
@@ -17,6 +17,10 @@ import re
import
logging
import
numpy
as
np
import
shutil
try
:
from
tqdm
import
tqdm
except
:
from
.utils
import
tqdm
from
inspect
import
isgeneratorfunction
from
....
import
io
from
....
import
core
...
...
@@ -357,38 +361,40 @@ class PostTrainingQuantization(object):
self
.
_set_activation_persistable
()
if
self
.
_algo
in
[
"KL"
,
"hist"
]:
_logger
.
info
(
"Preparation stage ..."
)
batch_id
=
0
with
tqdm
(
total
=
self
.
_batch_nums
,
bar_format
=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
data
in
self
.
_data_loader
():
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
,
scope
=
self
.
_scope
)
self
.
_collect_activation_abs_min_max
()
batch_id
+=
1
t
.
update
()
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
self
.
_init_sampling_act_histogram
()
batch_id
=
0
with
tqdm
(
total
=
self
.
_batch_nums
,
bar_format
=
'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
data
in
self
.
_data_loader
():
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
,
scope
=
self
.
_scope
)
self
.
_collect_activation_abs_min_max
()
if
batch_id
%
5
==
0
:
_logger
.
info
(
"Run batch: "
+
str
(
batch_id
))
self
.
_sampling
()
batch_id
+=
1
t
.
update
()
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
_logger
.
info
(
"Finish preparation stage, all batch:"
+
str
(
batch_id
))
self
.
_init_sampling_act_histogram
()
_logger
.
info
(
"Sampling stage ..."
)
batch_id
=
0
for
data
in
self
.
_data_loader
():
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
,
scope
=
self
.
_scope
)
self
.
_sampling
()
if
batch_id
%
5
==
0
:
_logger
.
info
(
"Run batch: "
+
str
(
batch_id
))
batch_id
+=
1
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
_logger
.
info
(
"Finish sampling stage, all batch: "
+
str
(
batch_id
))
if
self
.
_algo
==
'avg'
:
for
var_name
in
self
.
_quantized_act_var_name
:
...
...
@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
min_value
=
float
(
np
.
min
(
var_tensor
))
max_value
=
float
(
np
.
max
(
var_tensor
))
if
var_name
not
in
self
.
_sampling_act_abs_min_max
:
self
.
_sampling_act_abs_min_max
[
var_name
]
=
[
min_value
,
max_value
]
self
.
_sampling_act_abs_min_max
[
var_name
]
=
[
min_value
,
max_value
]
else
:
if
min_value
<
self
.
_sampling_act_abs_min_max
[
var_name
][
0
]:
self
.
_sampling_act_abs_min_max
[
var_name
][
0
]
=
min_value
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
abb0b2d6
...
...
@@ -14,6 +14,10 @@
import
collections
import
numpy
as
np
try
:
from
tqdm
import
tqdm
except
:
from
.utils
import
tqdm
from
.....
import
compat
as
cpt
from
....
import
core
from
....framework
import
IrGraph
...
...
@@ -294,10 +298,10 @@ class QuantizationTransformPass(object):
else
False
# if var node is weight and weight_preprocess_func is not None,
# will insert weight preprocess func
# will insert weight preprocess func
# to preorocess weight before quantization
# if var node is activation and act_preprocess_func is not None,
# will insert activation preprocess func
# if var node is activation and act_preprocess_func is not None,
# will insert activation preprocess func
# to preorocess activation before quantization
if
is_weight
and
self
.
_weight_preprocess_func
is
not
None
:
var_node
=
self
.
_insert_func
(
...
...
@@ -372,10 +376,15 @@ class QuantizationTransformPass(object):
graph
.
out_node_mapping_table
=
dict
()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
self
.
_is_skip_quant
(
graph
,
op
)
and
_has_weight
(
op
):
_transform_forward
(
graph
,
op
)
with
tqdm
(
total
=
len
(
ops
),
bar_format
=
'Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
self
.
_is_skip_quant
(
graph
,
op
)
and
_has_weight
(
op
):
_transform_forward
(
graph
,
op
)
t
.
update
()
# The loop for renaming the inputs of backward op.
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_grad_ops
and
_has_weight
(
op
):
...
...
@@ -1427,85 +1436,92 @@ class OutScaleForTrainingPass(object):
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
()
in
self
.
_teller_set
:
target_ops
.
append
(
op
)
for
op
in
target_ops
:
for
output_var_name
in
utils
.
_get_op_output_var_names
(
op
):
in_node
=
graph
.
_find_node_by_name
(
op
.
outputs
,
output_var_name
)
if
in_node
.
dtype
()
not
in
\
[
core
.
VarDesc
.
VarType
.
FP64
,
core
.
VarDesc
.
VarType
.
FP32
]:
continue
with
tqdm
(
total
=
len
(
target_ops
),
bar_format
=
'Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
op
in
target_ops
:
for
output_var_name
in
utils
.
_get_op_output_var_names
(
op
):
in_node
=
graph
.
_find_node_by_name
(
op
.
outputs
,
output_var_name
)
if
in_node
.
dtype
()
not
in
\
[
core
.
VarDesc
.
VarType
.
FP64
,
core
.
VarDesc
.
VarType
.
FP32
]:
continue
scale_node
=
graph
.
create_persistable_node
(
name
=
self
.
_scale_name
(
in_node
.
name
()),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
1
],
var_dtype
=
in_node
.
dtype
())
data_type
=
'float64'
if
in_node
.
dtype
()
\
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
_init_var_node
(
scale_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
ins
=
{
'X'
:
in_node
}
outs
=
{
'OutScale'
:
scale_node
}
if
not
self
.
_is_test
:
state_in_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'scale_state@'
),
scale_node
=
graph
.
create_persistable_node
(
name
=
self
.
_scale_name
(
in_node
.
name
()),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_dtype
=
in_node
.
dtype
(),
shape
=
[
1
])
shape
=
[
1
],
var_dtype
=
in_node
.
dtype
())
data_type
=
'float64'
if
in_node
.
dtype
()
\
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
_init_var_node
(
s
tate_in
_node
,
s
cale
_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
accum_in_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'scale_accum@'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_dtype
=
in_node
.
dtype
(),
shape
=
[
1
])
_init_var_node
(
accum_in_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
state_out_node
=
graph
.
create_var_node_from_desc
(
state_in_node
.
var
())
accum_out_node
=
graph
.
create_var_node_from_desc
(
accum_in_node
.
var
())
ins
[
'InState'
]
=
state_in_node
ins
[
'InAccum'
]
=
accum_in_node
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutAccum'
]
=
accum_out_node
attrs
=
{
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
}
scale_op_node
=
graph
.
create_op_node
(
op_type
=
'moving_average_abs_max_scale'
,
attrs
=
attrs
,
inputs
=
ins
,
outputs
=
outs
)
graph
.
link_to
(
in_node
,
scale_op_node
)
graph
.
link_to
(
scale_op_node
,
scale_node
)
if
not
self
.
_is_test
:
graph
.
link_to
(
state_in_node
,
scale_op_node
)
graph
.
link_to
(
accum_in_node
,
scale_op_node
)
graph
.
link_to
(
scale_op_node
,
state_out_node
)
graph
.
link_to
(
scale_op_node
,
accum_out_node
)
ins
=
{
'X'
:
in_node
}
outs
=
{
'OutScale'
:
scale_node
}
if
not
self
.
_is_test
:
state_in_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'scale_state@'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_dtype
=
in_node
.
dtype
(),
shape
=
[
1
])
_init_var_node
(
state_in_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
accum_in_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'scale_accum@'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_dtype
=
in_node
.
dtype
(),
shape
=
[
1
])
_init_var_node
(
accum_in_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
state_out_node
=
graph
.
create_var_node_from_desc
(
state_in_node
.
var
())
accum_out_node
=
graph
.
create_var_node_from_desc
(
accum_in_node
.
var
())
ins
[
'InState'
]
=
state_in_node
ins
[
'InAccum'
]
=
accum_in_node
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutAccum'
]
=
accum_out_node
attrs
=
{
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
}
scale_op_node
=
graph
.
create_op_node
(
op_type
=
'moving_average_abs_max_scale'
,
attrs
=
attrs
,
inputs
=
ins
,
outputs
=
outs
)
graph
.
link_to
(
in_node
,
scale_op_node
)
graph
.
link_to
(
scale_op_node
,
scale_node
)
if
not
self
.
_is_test
:
graph
.
link_to
(
state_in_node
,
scale_op_node
)
graph
.
link_to
(
accum_in_node
,
scale_op_node
)
graph
.
link_to
(
scale_op_node
,
state_out_node
)
graph
.
link_to
(
scale_op_node
,
accum_out_node
)
t
.
update
()
return
graph
def
_scale_name
(
self
,
var_name
):
"""
Return the scale name for the var named `var_name`.
"""
return
"%s
.
scale"
%
(
var_name
)
return
"%s
@
scale"
%
(
var_name
)
class
OutScaleForInferencePass
(
object
):
...
...
@@ -1564,7 +1580,7 @@ class OutScaleForInferencePass(object):
"""
Return the scale name for the var named `var_name`.
"""
return
"%s
.
scale"
%
(
var_name
)
return
"%s
@
scale"
%
(
var_name
)
class
AddQuantDequantPass
(
object
):
...
...
@@ -1644,36 +1660,43 @@ class AddQuantDequantPass(object):
# Forward stage, insert quant_dequant op
all_op_nodes
=
graph
.
all_op_nodes
()
for
op_node
in
all_op_nodes
:
if
op_node
.
name
()
in
self
.
_quantizable_op_type
:
is_skip
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
if
is_skip
or
is_quantized
or
\
(
not
_is_input_all_not_persistable
(
graph
,
op_node
)):
continue
with
tqdm
(
total
=
len
(
all_op_nodes
),
bar_format
=
'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
op_node
in
all_op_nodes
:
if
op_node
.
name
()
in
self
.
_quantizable_op_type
:
is_skip
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
if
is_skip
or
is_quantized
or
\
(
not
_is_input_all_not_persistable
(
graph
,
op_node
)):
continue
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
op_node
.
op
().
_set_attr
(
"activation_bits"
,
self
.
_quant_bits
)
op_node
.
op
().
_set_attr
(
"with_quant_attr"
,
True
)
arg_names
=
utils
.
_get_op_input_var_names
(
op_node
)
for
arg_name
in
arg_names
:
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
arg_name
)
if
arg_name
in
dequantized_vars_map
:
quant_var_node
=
dequantized_vars_map
[
arg_name
]
else
:
quant_var_node
,
_
=
\
self
.
_inser_quant_dequant_moving_average_abs_max_op
(
graph
,
in_node
,
self
.
_quant_bits
)
dequantized_vars_map
[
arg_name
]
=
quant_var_node
graph
.
update_input_link
(
in_node
,
quant_var_node
,
op_node
)
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
op_node
.
op
().
_set_attr
(
"activation_bits"
,
self
.
_quant_bits
)
op_node
.
op
().
_set_attr
(
"with_quant_attr"
,
True
)
arg_names
=
utils
.
_get_op_input_var_names
(
op_node
)
for
arg_name
in
arg_names
:
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
arg_name
)
if
arg_name
in
dequantized_vars_map
:
quant_var_node
=
dequantized_vars_map
[
arg_name
]
else
:
quant_var_node
,
_
=
\
self
.
_inser_quant_dequant_moving_average_abs_max_op
(
graph
,
in_node
,
self
.
_quant_bits
)
dequantized_vars_map
[
arg_name
]
=
quant_var_node
graph
.
update_input_link
(
in_node
,
quant_var_node
,
op_node
)
t
.
update
()
# Backward stage, update input link
for
op_node
in
all_op_nodes
:
...
...
@@ -2122,10 +2145,10 @@ class QuantizationTransformPassV2(object):
else
False
# if var node is weight and weight_preprocess_func is not None,
# will insert weight preprocess func
# will insert weight preprocess func
# to preorocess weight before quantization
# if var node is activation and act_preprocess_func is not None,
# will insert activation preprocess func
# if var node is activation and act_preprocess_func is not None,
# will insert activation preprocess func
# to preorocess activation before quantization
if
is_weight
and
self
.
_weight_preprocess_func
is
not
None
:
var_node
=
self
.
_insert_func
(
...
...
@@ -2240,10 +2263,16 @@ class QuantizationTransformPassV2(object):
graph
.
out_node_mapping_table
=
dict
()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
self
.
_is_skip_quant
(
graph
,
op
)
and
self
.
_has_weight
(
op
):
self
.
_transform_forward
(
graph
,
op
)
with
tqdm
(
total
=
len
(
ops
),
bar_format
=
'Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
self
.
_is_skip_quant
(
graph
,
op
)
and
self
.
_has_weight
(
op
):
self
.
_transform_forward
(
graph
,
op
)
t
.
update
()
# The loop for renaming the inputs of backward op.
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_grad_ops
and
self
.
_has_weight
(
op
):
...
...
@@ -2346,43 +2375,50 @@ class AddQuantDequantPassV2(object):
# Forward stage, insert quant_dequant op
all_op_nodes
=
graph
.
all_op_nodes
()
for
op_node
in
all_op_nodes
:
if
op_node
.
name
()
in
self
.
_quantizable_op_type
:
is_skip
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
if
is_skip
or
is_quantized
:
continue
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
arg_names
=
utils
.
_get_op_input_var_names
(
op_node
)
for
arg_name
in
arg_names
:
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
arg_name
)
if
in_node
.
persistable
():
with
tqdm
(
total
=
len
(
all_op_nodes
),
bar_format
=
'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
for
op_node
in
all_op_nodes
:
if
op_node
.
name
()
in
self
.
_quantizable_op_type
:
is_skip
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
if
is_skip
or
is_quantized
:
continue
if
arg_name
in
dequantized_vars_map
:
dequant_var_node
=
dequantized_vars_map
[
arg_name
]
else
:
insert_quant_pass
=
InsertQuantizeLinear
(
self
.
_place
,
self
.
_scope
,
quant_bits
=
self
.
_quant_bits
,
quant_axis
=-
1
,
channel_wise
=
False
,
is_test
=
self
.
_is_test
)
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
graph
,
in_node
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
)
dequantized_vars_map
[
arg_name
]
=
dequant_var_node
graph
.
update_input_link
(
in_node
,
dequant_var_node
,
op_node
)
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
arg_names
=
utils
.
_get_op_input_var_names
(
op_node
)
for
arg_name
in
arg_names
:
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
arg_name
)
if
in_node
.
persistable
():
continue
if
arg_name
in
dequantized_vars_map
:
dequant_var_node
=
dequantized_vars_map
[
arg_name
]
else
:
insert_quant_pass
=
InsertQuantizeLinear
(
self
.
_place
,
self
.
_scope
,
quant_bits
=
self
.
_quant_bits
,
quant_axis
=-
1
,
channel_wise
=
False
,
is_test
=
self
.
_is_test
)
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
graph
,
in_node
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
)
dequantized_vars_map
[
arg_name
]
=
dequant_var_node
graph
.
update_input_link
(
in_node
,
dequant_var_node
,
op_node
)
t
.
update
()
# Backward stage, update input link
for
op_node
in
all_op_nodes
:
...
...
python/paddle/fluid/contrib/slim/quantization/utils.py
浏览文件 @
abb0b2d6
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
numpy
as
np
from
....framework
import
IrNode
from
....framework
import
Operator
...
...
@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
"leaky_relu"
,
"tanh"
,
"swish"
,
"scale"
,
"transpose"
,
"transpose2"
,
"sigmoid"
,
...
...
@@ -162,7 +162,6 @@ _op_real_in_out_name = {
"sigmoid"
:
[[
"X"
],
[
"Out"
]],
"elementwise_mul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"elementwise_pow"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"scale"
:
[[
"X"
],
[
"Out"
]],
"hard_swish"
:
[[
"X"
],
[
"Out"
]],
"hard_sigmoid"
:
[[
"X"
],
[
"Out"
]],
"gru"
:
[[
"Input"
,
"Weight"
],
[
"Hidden"
]],
...
...
@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
cos_sim
=
np
.
inner
(
orig_tensor
.
flatten
(),
qdq_tensor
.
flatten
())
\
/
(
np
.
linalg
.
norm
(
orig_tensor
.
flatten
())
*
np
.
linalg
.
norm
(
qdq_tensor
.
flatten
()))
return
cos_sim
class
tqdm
(
object
):
def
__init__
(
self
,
total
,
bar_format
=
'Loading|{bar}'
,
ncols
=
80
):
self
.
total
=
total
self
.
bar_format
=
bar_format
self
.
ncols
=
ncols
self
.
n
=
0
def
update
(
self
,
n
=
1
):
self
.
n
+=
n
a
=
"="
*
round
((
self
.
n
/
self
.
total
)
*
self
.
ncols
)
b
=
" "
*
(
self
.
ncols
-
len
(
a
))
prefix
=
self
.
bar_format
.
split
(
'|'
)[
0
]
sys
.
stderr
.
write
(
"
\r
{}|{}=>{}| {}/{}"
.
format
(
prefix
,
a
,
b
,
self
.
n
,
self
.
total
))
sys
.
stderr
.
flush
()
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
sys
.
stderr
.
write
(
'
\n
'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录