Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
7c8f7df2
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7c8f7df2
编写于
2月 21, 2019
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add some op_des funs to IrOpNode and add some var_des funs to IrVarNode. test=develop
上级
33f99d61
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
99 addition
and
27 deletion
+99
-27
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+27
-27
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+72
-0
未找到文件。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
7c8f7df2
...
@@ -231,14 +231,14 @@ class QuantizationTransformPass(object):
...
@@ -231,14 +231,14 @@ class QuantizationTransformPass(object):
quant_var_node
=
graph
.
create_var_node
(
quant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
var
().
shape
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
scale_var_node
=
graph
.
create_var_node
(
scale_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
var
().
shape
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_abs_max'
,
op_type
=
'fake_quantize_abs_max'
,
attrs
=
{
attrs
=
{
...
@@ -261,15 +261,15 @@ class QuantizationTransformPass(object):
...
@@ -261,15 +261,15 @@ class QuantizationTransformPass(object):
quant_var_node
=
graph
.
create_var_node
(
quant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
var
().
shape
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
scale_in_node
=
graph
.
create_persistable_node
(
scale_in_node
=
graph
.
create_persistable_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
1
],
shape
=
[
1
],
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
self
.
_need_initialized
[
scale_in_node
.
var
()]
=
Constant
(
value
=
0.001
)
self
.
_need_initialized
[
scale_in_node
.
var
()]
=
Constant
(
value
=
0.001
)
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
...
@@ -282,7 +282,7 @@ class QuantizationTransformPass(object):
...
@@ -282,7 +282,7 @@ class QuantizationTransformPass(object):
name
=
unique_name
.
generate
(
'scales'
),
name
=
unique_name
.
generate
(
'scales'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
self
.
_window_size
],
shape
=
[
self
.
_window_size
],
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
self
.
_need_initialized
[
scales_node
.
var
()]
=
Constant
(
value
=
0
)
self
.
_need_initialized
[
scales_node
.
var
()]
=
Constant
(
value
=
0
)
inputs
[
'Iter'
]
=
self
.
_global_step
inputs
[
'Iter'
]
=
self
.
_global_step
outputs
[
'OutScales'
]
=
scales_node
outputs
[
'OutScales'
]
=
scales_node
...
@@ -317,9 +317,9 @@ class QuantizationTransformPass(object):
...
@@ -317,9 +317,9 @@ class QuantizationTransformPass(object):
dequant_var_node
=
graph
.
create_var_node
(
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
var_node
.
name
()),
name
=
self
.
_dequantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
var
().
shape
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
var_dtype
=
var_node
.
dtype
())
max_range
=
(
1
<<
(
quant_bits
-
1
))
-
1
max_range
=
(
1
<<
(
quant_bits
-
1
))
-
1
dequant_op_node
=
graph
.
create_op_node
(
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_dequantize_max_abs'
,
op_type
=
'fake_dequantize_max_abs'
,
...
@@ -408,17 +408,17 @@ class QuantizationFreezePass(object):
...
@@ -408,17 +408,17 @@ class QuantizationFreezePass(object):
for
op_node
in
ops
:
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_fake_quant_op_names
:
if
op_name
in
self
.
_fake_quant_op_names
:
input_arg_name
=
op_node
.
op
().
input
(
'X'
)[
0
]
input_arg_name
=
op_node
.
input
(
'X'
)[
0
]
if
input_arg_name
in
persistable_vars
:
if
input_arg_name
in
persistable_vars
:
if
self
.
_weight_quantize_type
==
'abs_max'
:
if
self
.
_weight_quantize_type
==
'abs_max'
:
param
=
self
.
_load_var
(
input_arg_name
)
param
=
self
.
_load_var
(
input_arg_name
)
scale_v
=
np
.
max
(
np
.
abs
(
param
))
scale_v
=
np
.
max
(
np
.
abs
(
param
))
else
:
else
:
scale_v
=
self
.
_load_var
(
op_node
.
op
().
output
(
'OutScale'
)
scale_v
=
self
.
_load_var
(
[
0
])[
0
]
op_node
.
output
(
'OutScale'
)
[
0
])[
0
]
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
else
:
else
:
scale_v
=
graph
.
var_node
(
op_node
.
o
p
().
o
utput
(
'OutScale'
)[
0
])
scale_v
=
graph
.
var_node
(
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
if
input_arg_name
in
persistable_vars
:
if
input_arg_name
in
persistable_vars
:
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
...
@@ -454,8 +454,8 @@ class QuantizationFreezePass(object):
...
@@ -454,8 +454,8 @@ class QuantizationFreezePass(object):
return
graph
return
graph
def
_remove_fake_quant_and_dequant_op
(
self
,
graph
,
op_node
):
def
_remove_fake_quant_and_dequant_op
(
self
,
graph
,
op_node
):
k
=
op_node
.
o
p
().
o
utput
(
'Out'
)[
0
]
k
=
op_node
.
output
(
'Out'
)[
0
]
v
=
op_node
.
op
().
input
(
'X'
)[
0
]
v
=
op_node
.
input
(
'X'
)[
0
]
if
v
not
in
self
.
_op_input_rename_map
:
if
v
not
in
self
.
_op_input_rename_map
:
self
.
_op_input_rename_map
[
k
]
=
v
self
.
_op_input_rename_map
[
k
]
=
v
else
:
else
:
...
@@ -493,9 +493,9 @@ class QuantizationFreezePass(object):
...
@@ -493,9 +493,9 @@ class QuantizationFreezePass(object):
output_var_node
=
op_node
.
outputs
[
0
]
output_var_node
=
op_node
.
outputs
[
0
]
dequant_var_node
=
graph
.
create_var_node
(
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
var_type
=
output_var_node
.
var
().
type
(),
var_type
=
output_var_node
.
type
(),
shape
=
output_var_node
.
var
().
shape
(),
shape
=
output_var_node
.
shape
(),
var_dtype
=
output_var_node
.
var
().
dtype
())
var_dtype
=
output_var_node
.
dtype
())
dequant_op_node
=
graph
.
create_op_node
(
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_dequantize_max_abs'
,
op_type
=
'fake_dequantize_max_abs'
,
attrs
=
{
attrs
=
{
...
@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object):
...
@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object):
int8_var_node_name
=
var_node
.
name
()
+
".int8"
int8_var_node_name
=
var_node
.
name
()
+
".int8"
int8_var_node
=
graph
.
create_persistable_node
(
int8_var_node
=
graph
.
create_persistable_node
(
name
=
cpt
.
to_text
(
int8_var_node_name
),
name
=
cpt
.
to_text
(
int8_var_node_name
),
var_type
=
var_node
.
var
().
type
(),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
var
().
shape
(),
shape
=
var_node
.
shape
(),
var_dtype
=
core
.
VarDesc
.
VarType
.
INT8
)
var_dtype
=
core
.
VarDesc
.
VarType
.
INT8
)
array
=
self
.
_load_var
(
var_node
.
name
())
array
=
self
.
_load_var
(
var_node
.
name
())
self
.
_scope
.
var
(
int8_var_node_name
)
self
.
_scope
.
var
(
int8_var_node_name
)
...
@@ -672,7 +672,7 @@ class TransformForMobilePass(object):
...
@@ -672,7 +672,7 @@ class TransformForMobilePass(object):
for
op_node
in
ops
:
for
op_node
in
ops
:
name
=
op_node
.
name
()
name
=
op_node
.
name
()
if
name
in
self
.
_fake_quant_op_names
:
if
name
in
self
.
_fake_quant_op_names
:
op_node
.
op
().
set_type
(
'quantize'
)
op_node
.
set_type
(
'quantize'
)
quant_node
=
graph
.
create_op_node_from_desc
(
op_node
.
op
())
quant_node
=
graph
.
create_op_node_from_desc
(
op_node
.
op
())
for
input_node
in
op_node
.
inputs
:
for
input_node
in
op_node
.
inputs
:
graph
.
link_to
(
input_node
,
quant_node
)
graph
.
link_to
(
input_node
,
quant_node
)
...
@@ -680,7 +680,7 @@ class TransformForMobilePass(object):
...
@@ -680,7 +680,7 @@ class TransformForMobilePass(object):
graph
.
link_to
(
quant_node
,
output_node
)
graph
.
link_to
(
quant_node
,
output_node
)
graph
.
safe_remove_nodes
(
op_node
)
graph
.
safe_remove_nodes
(
op_node
)
if
name
in
self
.
_fake_dequant_op_names
:
if
name
in
self
.
_fake_dequant_op_names
:
op_node
.
op
().
set_type
(
'dequantize'
)
op_node
.
set_type
(
'dequantize'
)
dequant_node
=
graph
.
create_op_node_from_desc
(
op_node
.
op
())
dequant_node
=
graph
.
create_op_node_from_desc
(
op_node
.
op
())
for
input_node
in
op_node
.
inputs
:
for
input_node
in
op_node
.
inputs
:
graph
.
link_to
(
input_node
,
dequant_node
)
graph
.
link_to
(
input_node
,
dequant_node
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
7c8f7df2
...
@@ -1754,6 +1754,39 @@ class IrVarNode(IrNode):
...
@@ -1754,6 +1754,39 @@ class IrVarNode(IrNode):
"The node variable description cannot be None."
"The node variable description cannot be None."
return
self
.
node
.
var
().
persistable
()
return
self
.
node
.
var
().
persistable
()
def
type
(
self
):
"""
Return the variable type.
Returns:
core.VarDesc.VarType: the variable type.
"""
assert
self
.
node
.
var
()
is
not
None
,
\
"The node variable description cannot be None."
return
self
.
node
.
var
().
type
()
def
dtype
(
self
):
"""
Return the variable data type.
Returns:
core.VarDesc.VarType: the variable data type.
"""
assert
self
.
node
.
var
()
is
not
None
,
\
"The node variable description cannot be None."
return
self
.
node
.
var
().
dtype
()
def
shape
(
self
):
"""
Return the variable shape.
Returns:
list: the variable shape.
"""
assert
self
.
node
.
var
()
is
not
None
,
\
"The node variable description cannot be None."
return
self
.
node
.
var
().
shape
()
@
property
@
property
def
inputs
(
self
):
def
inputs
(
self
):
"""
"""
...
@@ -1804,6 +1837,45 @@ class IrOpNode(IrNode):
...
@@ -1804,6 +1837,45 @@ class IrOpNode(IrNode):
"The node operator description cannot be None."
"The node operator description cannot be None."
self
.
node
.
op
().
_rename_input
(
old_input_name
,
new_input_name
)
self
.
node
.
op
().
_rename_input
(
old_input_name
,
new_input_name
)
def
input
(
self
,
name
):
"""
Get the argument name list by the parameter name for input.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert
self
.
node
.
op
()
is
not
None
,
\
"The node operator description cannot be None."
return
self
.
node
.
op
().
input
(
name
)
def
output
(
self
,
name
):
"""
Get the argument name list by the parameter name for output.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert
self
.
node
.
op
()
is
not
None
,
\
"The node operator description cannot be None."
return
self
.
node
.
op
().
output
(
name
)
def
set_type
(
self
,
new_type
):
"""
Change the operator type into new type.
Args:
new_type(str): new operator type to be set.
"""
assert
self
.
node
.
op
()
is
not
None
,
\
"The node operator description cannot be None."
return
self
.
node
.
op
().
set_type
(
new_type
)
@
property
@
property
def
inputs
(
self
):
def
inputs
(
self
):
"""
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录