Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f7f5044b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f7f5044b
编写于
3月 28, 2019
作者:
Z
Zhen Wang
提交者:
GitHub
3月 28, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16489 from wzzju/fix_slim_quant_bugs
Clean codes and fix some bugs.
上级
69cb9792
46e1bb06
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
93 addition
and
108 deletion
+93
-108
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+65
-55
python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
.../fluid/contrib/slim/quantization/quantization_strategy.py
+11
-5
python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml
...addle/fluid/contrib/slim/tests/quantization/compress.yaml
+2
-0
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+0
-3
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+15
-45
未找到文件。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
f7f5044b
...
@@ -26,6 +26,17 @@ __all__ = [
...
@@ -26,6 +26,17 @@ __all__ = [
]
]
def
_init_var_node
(
var_node
,
value
,
scope
,
place
):
assert
isinstance
(
value
,
np
.
ndarray
),
'The type of value should be numpy array.'
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
assert
place
is
not
None
,
\
'The place cannot be set None.'
tensor
=
scope
.
var
(
var_node
.
name
()).
get_tensor
()
tensor
.
set
(
value
,
place
)
class
QuantizationTransformPass
(
object
):
class
QuantizationTransformPass
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
scope
=
None
,
scope
=
None
,
...
@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
...
@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
assert
activation_quantize_type
!=
'channel_wise_abs_max'
,
"The activation quantization type does not support 'channel_wise_abs_max'."
assert
activation_quantize_type
!=
'channel_wise_abs_max'
,
"The activation quantization type does not support 'channel_wise_abs_max'."
if
activation_quantize_type
not
in
quant_type
:
if
activation_quantize_type
not
in
quant_type
:
raise
ValueError
(
raise
ValueError
(
"Unknown activation_quantize_type : '%s'. It can only be "
,
"Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
,
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
%
str
(
activation_quantize_type
))
(
str
(
activation_quantize_type
)
))
if
weight_quantize_type
not
in
quant_type
:
if
weight_quantize_type
not
in
quant_type
:
raise
ValueError
(
raise
ValueError
(
"Unknown weight_quantize_type: '%s'. It can only be "
,
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
,
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
str
(
weight_quantize_type
))
%
(
str
(
weight_quantize_type
)
))
self
.
_activation_quantize_type
=
activation_quantize_type
self
.
_activation_quantize_type
=
activation_quantize_type
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_weight_quantize_type
=
weight_quantize_type
...
@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
...
@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
"""
"""
assert
isinstance
(
graph
,
assert
isinstance
(
graph
,
IrGraph
),
'graph must be the instance of IrGraph.'
IrGraph
),
'graph must be the instance of IrGraph.'
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self
.
_is_test
=
graph
.
is_test
()
self
.
_is_test
=
graph
.
is_test
()
# marked the variable which has been dequantized.
# marked the variable which has been dequantized.
dequantized_vars
=
collections
.
OrderedDict
()
dequantized_vars
=
collections
.
OrderedDict
()
...
@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
...
@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
1
],
shape
=
[
1
],
var_dtype
=
core
.
VarDesc
.
VarType
.
INT64
)
var_dtype
=
core
.
VarDesc
.
VarType
.
INT64
)
self
.
_init_var_node
(
_init_var_node
(
global_step_in
,
np
.
zeros
(
global_step_in
,
[
1
],
dtype
=
'int64'
))
np
.
zeros
(
[
1
],
dtype
=
'int64'
),
self
.
_scope
,
self
.
_place
)
global_step_out
=
graph
.
create_var_node_from_desc
(
global_step_out
=
graph
.
create_var_node_from_desc
(
global_step_in
.
var
())
global_step_in
.
var
())
# The attribute of `op_role` is needed by ParallelExecutor.
# The attribute of `op_role` is needed by ParallelExecutor.
...
@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
...
@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
var_dtype
=
var_node
.
dtype
())
var_dtype
=
var_node
.
dtype
())
data_type
=
'float64'
if
var_node
.
dtype
(
data_type
=
'float64'
if
var_node
.
dtype
(
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
self
.
_init_var_node
(
scale_in_node
,
np
.
array
([
0.001
],
dtype
=
data_type
))
_init_var_node
(
scale_in_node
,
np
.
array
(
[
0.001
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
inputs
=
{
'X'
:
var_node
,
'InScale'
:
scale_in_node
}
inputs
=
{
'X'
:
var_node
,
'InScale'
:
scale_in_node
}
...
@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
...
@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
var_dtype
=
var_node
.
dtype
())
var_dtype
=
var_node
.
dtype
())
data_type
=
'float64'
if
var_node
.
dtype
(
data_type
=
'float64'
if
var_node
.
dtype
(
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
self
.
_init_var_node
(
_init_var_node
(
scales_node
,
np
.
zeros
(
scales_node
,
[
self
.
_window_size
],
dtype
=
data_type
))
np
.
zeros
(
[
self
.
_window_size
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
inputs
[
'Iter'
]
=
self
.
_global_step
inputs
[
'Iter'
]
=
self
.
_global_step
outputs
[
'OutScales'
]
=
scales_node
outputs
[
'OutScales'
]
=
scales_node
attrs
=
{
attrs
=
{
...
@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
...
@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
var_dtype
=
var_node
.
dtype
())
var_dtype
=
var_node
.
dtype
())
data_type
=
'float64'
if
var_node
.
dtype
(
data_type
=
'float64'
if
var_node
.
dtype
(
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
self
.
_init_var_node
(
scale_in_node
,
np
.
array
([
0.001
],
dtype
=
data_type
))
_init_var_node
(
scale_in_node
,
np
.
array
(
[
0.001
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
ins
=
{
'X'
:
var_node
,
'InScale'
:
scale_in_node
}
ins
=
{
'X'
:
var_node
,
'InScale'
:
scale_in_node
}
...
@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
...
@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
shape
=
[
1
])
shape
=
[
1
])
data_type
=
'float64'
if
var_node
.
dtype
(
data_type
=
'float64'
if
var_node
.
dtype
(
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
self
.
_init_var_node
(
scale_in_node
,
np
.
ones
([
1
],
dtype
=
data_type
))
_init_var_node
(
scale_in_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
accum_in_node
=
graph
.
create_persistable_node
(
accum_in_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'accum'
),
name
=
unique_name
.
generate
(
'accum'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_dtype
=
var_node
.
dtype
(),
var_dtype
=
var_node
.
dtype
(),
shape
=
[
1
])
shape
=
[
1
])
self
.
_init_var_node
(
accum_in_node
,
np
.
ones
([
1
],
dtype
=
data_type
))
_init_var_node
(
accum_in_node
,
np
.
ones
(
[
1
],
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
state_out_node
=
graph
.
create_var_node_from_desc
(
state_in_node
.
var
(
state_out_node
=
graph
.
create_var_node_from_desc
(
state_in_node
.
var
(
))
))
accum_out_node
=
graph
.
create_var_node_from_desc
(
accum_in_node
.
var
(
accum_out_node
=
graph
.
create_var_node_from_desc
(
accum_in_node
.
var
(
...
@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
...
@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
graph
.
link_to
(
dequant_op_node
,
dequant_var_node
)
graph
.
link_to
(
dequant_op_node
,
dequant_var_node
)
return
dequant_var_node
return
dequant_var_node
def
_init_var_node
(
self
,
var_node
,
value
):
assert
isinstance
(
value
,
np
.
ndarray
),
'The type of value should be numpy array.'
assert
self
.
_scope
is
not
None
,
\
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert
self
.
_place
is
not
None
,
\
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor
=
self
.
_scope
.
var
(
var_node
.
name
()).
get_tensor
()
tensor
.
set
(
value
,
self
.
_place
)
def
_quantized_var_name
(
self
,
var_name
):
def
_quantized_var_name
(
self
,
var_name
):
"""
"""
Return quantized variable name for the input `var_name`.
Return quantized variable name for the input `var_name`.
...
@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
...
@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
self
.
_weight_bits
)
self
.
_weight_bits
)
self
.
_restore_var
(
input_arg_name
,
quantized_param_v
)
self
.
_restore_var
(
input_arg_name
,
quantized_param_v
)
else
:
else
:
scale_v
=
self
.
_to_node
(
op_node
.
outputs
,
scale_v
=
graph
.
_find_node_by_name
(
op_node
.
output
(
'OutScale'
)[
0
])
op_node
.
outputs
,
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
ops
=
graph
.
all_op_nodes
()
ops
=
graph
.
all_op_nodes
()
...
@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
...
@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
return
graph
return
graph
def
_remove_fake_quant_and_dequant_op
(
self
,
graph
,
op_node
):
def
_remove_fake_quant_and_dequant_op
(
self
,
graph
,
op_node
):
k
=
self
.
_to_nod
e
(
op_node
.
outputs
,
op_node
.
output
(
'Out'
)[
0
])
k
=
graph
.
_find_node_by_nam
e
(
op_node
.
outputs
,
op_node
.
output
(
'Out'
)[
0
])
v
=
self
.
_to_nod
e
(
op_node
.
inputs
,
op_node
.
input
(
'X'
)[
0
])
v
=
graph
.
_find_node_by_nam
e
(
op_node
.
inputs
,
op_node
.
input
(
'X'
)[
0
])
if
v
.
node
not
in
self
.
_op_input_rename_map
:
if
v
.
node
not
in
self
.
_op_input_rename_map
:
self
.
_op_input_rename_map
[
k
.
node
]
=
v
self
.
_op_input_rename_map
[
k
.
node
]
=
v
else
:
else
:
...
@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
...
@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
raise
ValueError
(
"Only support one output, but op %s has"
raise
ValueError
(
"Only support one output, but op %s has"
" more than one output."
%
(
op_node
.
name
()))
" more than one output."
%
(
op_node
.
name
()))
output_var_node
=
self
.
_to_node
(
op_node
.
outputs
,
output_var_node
=
graph
.
_find_node_by_name
(
op_node
.
output_arg_names
()[
0
])
op_node
.
outputs
,
op_node
.
output_arg_names
()[
0
])
weight_scale_node
=
graph
.
create_persistable_node
(
weight_scale_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'channel_scale'
),
name
=
unique_name
.
generate
(
'channel_scale'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
...
@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
...
@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
var_dtype
=
output_var_node
.
dtype
())
var_dtype
=
output_var_node
.
dtype
())
data_type
=
'float64'
if
output_var_node
.
dtype
(
data_type
=
'float64'
if
output_var_node
.
dtype
(
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
)
==
core
.
VarDesc
.
VarType
.
FP64
else
'float32'
self
.
_init_var_node
(
weight_scale_node
,
channel_scale
.
astype
(
data_type
))
_init_var_node
(
weight_scale_node
,
channel_scale
.
astype
(
data_type
),
self
.
_scope
,
self
.
_place
)
dequant_var_node
=
graph
.
create_var_node
(
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
var_type
=
output_var_node
.
type
(),
var_type
=
output_var_node
.
type
(),
...
@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
...
@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
raise
ValueError
(
"Only support one output, but op %s has"
raise
ValueError
(
"Only support one output, but op %s has"
" more than one output."
%
(
op_node
.
name
()))
" more than one output."
%
(
op_node
.
name
()))
output_var_node
=
self
.
_to_node
(
op_node
.
outputs
,
output_var_node
=
graph
.
_find_node_by_name
(
op_node
.
output_arg_names
()[
0
])
op_node
.
outputs
,
op_node
.
output_arg_names
()[
0
])
dequant_var_node
=
graph
.
create_var_node
(
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
var_type
=
output_var_node
.
type
(),
var_type
=
output_var_node
.
type
(),
...
@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
...
@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
self
.
_op_output_rename_map
[
output_var_node
.
node
]
=
dequant_var_node
self
.
_op_output_rename_map
[
output_var_node
.
node
]
=
dequant_var_node
return
dequant_var_node
return
dequant_var_node
def
_init_var_node
(
self
,
var_node
,
value
):
assert
isinstance
(
value
,
np
.
ndarray
),
'The type of value should be numpy array.'
assert
self
.
_scope
is
not
None
,
\
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert
self
.
_place
is
not
None
,
\
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor
=
self
.
_scope
.
var
(
var_node
.
name
()).
get_tensor
()
tensor
.
set
(
value
,
self
.
_place
)
def
_to_node
(
self
,
nodes
,
node_name
):
target_node
=
None
for
n
in
nodes
:
if
n
.
name
()
==
node_name
:
target_node
=
n
assert
target_node
is
not
None
,
"Cannot find the target node in the giving set."
return
target_node
def
_load_var
(
self
,
name
):
def
_load_var
(
self
,
name
):
return
np
.
array
(
self
.
_scope
.
find_var
(
name
).
get_tensor
())
return
np
.
array
(
self
.
_scope
.
find_var
(
name
).
get_tensor
())
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
浏览文件 @
f7f5044b
...
@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
...
@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
activation_bits
=
8
,
activation_bits
=
8
,
weight_bits
=
8
,
weight_bits
=
8
,
activation_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'abs_max'
,
weight_quantize_type
=
'abs_max'
,
save_in_nodes
=
None
,
save_in_nodes
=
None
,
save_out_nodes
=
None
):
save_out_nodes
=
None
):
"""
"""
Args:
Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights.
float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. defalut: None.
None means it doesn't save float model. defalut: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. defalut: None.
None means it doesn't save mobile model. defalut: None.
...
@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
...
@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
dynamically each step in both training and testing period. If use
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
during training and used in inference.
save_in_nodes(list<str>): A list of variable names used to prune graph
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph
save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
for saving inference model.
"""
"""
...
@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy):
...
@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy):
self
.
activation_bits
=
activation_bits
self
.
activation_bits
=
activation_bits
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
activation_quantize_type
=
activation_quantize_type
self
.
activation_quantize_type
=
activation_quantize_type
self
.
weight_quantize_type
=
weight_quantize_type
self
.
save_out_nodes
=
save_out_nodes
self
.
save_out_nodes
=
save_out_nodes
self
.
save_in_nodes
=
save_in_nodes
self
.
save_in_nodes
=
save_in_nodes
...
@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy):
...
@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy):
place
=
context
.
place
,
place
=
context
.
place
,
weight_bits
=
self
.
weight_bits
,
weight_bits
=
self
.
weight_bits
,
activation_bits
=
self
.
activation_bits
,
activation_bits
=
self
.
activation_bits
,
activation_quantize_type
=
self
.
activation_quantize_type
)
activation_quantize_type
=
self
.
activation_quantize_type
,
weight_quantize_type
=
self
.
weight_quantize_type
)
transform_pass
.
apply
(
train_ir_graph
)
transform_pass
.
apply
(
train_ir_graph
)
transform_pass
.
apply
(
test_ir_graph
)
transform_pass
.
apply
(
test_ir_graph
)
...
@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy):
...
@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy):
scope
=
context
.
scope
,
scope
=
context
.
scope
,
place
=
context
.
place
,
place
=
context
.
place
,
weight_bits
=
self
.
weight_bits
,
weight_bits
=
self
.
weight_bits
,
activation_bits
=
self
.
activation_bits
)
activation_bits
=
self
.
activation_bits
,
weight_quantize_type
=
self
.
weight_quantize_type
)
freeze_pass
.
apply
(
test_ir_graph
)
freeze_pass
.
apply
(
test_ir_graph
)
# for other strategies
# for other strategies
...
...
python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml
浏览文件 @
f7f5044b
...
@@ -35,6 +35,8 @@ strategies:
...
@@ -35,6 +35,8 @@ strategies:
start_epoch
:
0
start_epoch
:
0
end_epoch
:
0
end_epoch
:
0
float_model_save_path
:
'
./output/float'
float_model_save_path
:
'
./output/float'
mobile_model_save_path
:
'
./output/mobile'
int8_model_save_path
:
'
./output/int8'
weight_bits
:
8
weight_bits
:
8
activation_bits
:
8
activation_bits
:
8
weight_quantize_type
:
'
abs_max'
weight_quantize_type
:
'
abs_max'
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
f7f5044b
...
@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
...
@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
place
=
place
,
place
=
place
,
activation_quantize_type
=
activation_quant_type
,
activation_quantize_type
=
activation_quant_type
,
weight_quantize_type
=
weight_quant_type
)
weight_quantize_type
=
weight_quant_type
)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass
.
apply
(
main_graph
)
transform_pass
.
apply
(
main_graph
)
transform_pass
.
apply
(
test_graph
)
transform_pass
.
apply
(
test_graph
)
dev_name
=
'_gpu_'
if
use_cuda
else
'_cpu_'
dev_name
=
'_gpu_'
if
use_cuda
else
'_cpu_'
...
@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
...
@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass
=
QuantizationFreezePass
(
freeze_pass
=
QuantizationFreezePass
(
scope
=
scope
,
place
=
place
,
weight_quantize_type
=
weight_quant_type
)
scope
=
scope
,
place
=
place
,
weight_quantize_type
=
weight_quant_type
)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass
.
apply
(
test_graph
)
freeze_pass
.
apply
(
test_graph
)
if
not
for_ci
:
if
not
for_ci
:
marked_nodes
=
set
()
marked_nodes
=
set
()
...
...
python/paddle/fluid/framework.py
浏览文件 @
f7f5044b
...
@@ -2347,40 +2347,6 @@ class IrGraph(object):
...
@@ -2347,40 +2347,6 @@ class IrGraph(object):
"""
"""
return
{
IrOpNode
(
node
)
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_op
()}
return
{
IrOpNode
(
node
)
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_op
()}
def
_find_var_node
(
self
,
key
):
"""
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name or id.
"""
target_var_node
=
None
var_nodes
=
self
.
all_var_nodes
()
if
isinstance
(
key
,
six
.
string_types
):
for
var_node
in
var_nodes
:
if
var_node
.
name
()
==
key
:
target_var_node
=
var_node
elif
isinstance
(
key
,
int
):
for
var_node
in
var_nodes
:
if
var_node
.
id
()
==
key
:
target_var_node
=
var_node
if
target_var_node
is
None
:
raise
ValueError
(
"var_node %s not in this graph"
%
key
)
return
target_var_node
def
create_persistable_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
def
create_persistable_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
"""
"""
Create a persistable variable node in the graph. In IrGraph,
Create a persistable variable node in the graph. In IrGraph,
...
@@ -2525,14 +2491,6 @@ class IrGraph(object):
...
@@ -2525,14 +2491,6 @@ class IrGraph(object):
core
.
graph_safe_remove_nodes
(
self
.
graph
,
original_nodes
)
core
.
graph_safe_remove_nodes
(
self
.
graph
,
original_nodes
)
def
resolve_hazard
(
self
):
def
resolve_hazard
(
self
):
def
_to_node
(
nodes
,
node_name
):
target_node
=
None
for
n
in
nodes
:
if
n
.
name
()
==
node_name
:
target_node
=
n
assert
target_node
is
not
None
,
"Cannot find the target node in the giving set."
return
target_node
ordered_nodes
=
core
.
topology_sort
(
self
.
graph
)
ordered_nodes
=
core
.
topology_sort
(
self
.
graph
)
var_nodes
=
dict
()
var_nodes
=
dict
()
for
node
in
ordered_nodes
:
for
node
in
ordered_nodes
:
...
@@ -2540,16 +2498,17 @@ class IrGraph(object):
...
@@ -2540,16 +2498,17 @@ class IrGraph(object):
for
each_var_name
in
node
.
op
().
input_arg_names
():
for
each_var_name
in
node
.
op
().
input_arg_names
():
if
each_var_name
not
in
var_nodes
:
if
each_var_name
not
in
var_nodes
:
var_nodes
[
each_var_name
]
=
[
var_nodes
[
each_var_name
]
=
[
_to_nod
e
(
node
.
inputs
,
each_var_name
)
self
.
_find_node_by_nam
e
(
node
.
inputs
,
each_var_name
)
]
]
for
each_var_name
in
node
.
op
().
output_arg_names
():
for
each_var_name
in
node
.
op
().
output_arg_names
():
if
each_var_name
not
in
var_nodes
:
if
each_var_name
not
in
var_nodes
:
var_nodes
[
each_var_name
]
=
[
var_nodes
[
each_var_name
]
=
[
_to_nod
e
(
node
.
outputs
,
each_var_name
)
self
.
_find_node_by_nam
e
(
node
.
outputs
,
each_var_name
)
]
]
else
:
else
:
var_nodes
[
each_var_name
].
append
(
var_nodes
[
each_var_name
].
append
(
_to_node
(
node
.
outputs
,
each_var_name
))
self
.
_find_node_by_name
(
node
.
outputs
,
each_var_name
))
self
.
graph
.
resolve_hazard
(
var_nodes
)
self
.
graph
.
resolve_hazard
(
var_nodes
)
def
has_circle
(
self
):
def
has_circle
(
self
):
...
@@ -2662,6 +2621,17 @@ class IrGraph(object):
...
@@ -2662,6 +2621,17 @@ class IrGraph(object):
program
=
Program
.
_construct_from_desc
(
desc
)
program
=
Program
.
_construct_from_desc
(
desc
)
return
program
return
program
def
_find_node_by_name
(
self
,
nodes
,
node_name
):
"""
Find a node in the giving nodes set by the name.
"""
target_node
=
None
for
n
in
nodes
:
if
n
.
name
()
==
node_name
:
target_node
=
n
assert
target_node
is
not
None
,
"Cannot find the target node in the giving set."
return
target_node
def
_update_desc_attr
(
self
,
desc
,
name
,
val
):
def
_update_desc_attr
(
self
,
desc
,
name
,
val
):
"""
"""
Update the value of desc's attribute by attribute's name.
Update the value of desc's attribute by attribute's name.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录