Commit 33f99d61
Authored Feb 20, 2019 by Zhen Wang

add IrNode&IrVarNode&IrOpNode. test=develop

Parent: d8128930
Showing 4 changed files with 409 additions and 106 deletions (+409 −106)
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py   +41  −28
python/paddle/fluid/contrib/slim/tests/test_graph.py                  +3   −3
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py     +26  −31
python/paddle/fluid/framework.py                                    +339  −44
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

@@ -17,7 +17,9 @@ import numpy as np
 import six
 from ..... import compat as cpt
 from .... import core
+from .... import Executor
 from ....framework import IrGraph
+from ....framework import IrNode
 from ....framework import Program
 from ....initializer import Constant
 from .... import unique_name
@@ -31,7 +33,7 @@ __all__ = [
 class QuantizationTransformPass(object):
     def __init__(self,
                  scope=None,
-                 program_exe=None,
+                 place=None,
                  weight_bits=8,
                  activation_bits=8,
                  activation_quantize_type='abs_max',
@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
             scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
                 type, this pass will create some new parameters. The scope is used to
                 initialize these new parameters.
-            program_exe(fluid.Executor): program_exe is used to initialize new
+            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                 parameters described above.
             weight_bits (int): quantization bit number for weights,
                 the bias is not quantized.
@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
             from paddle.fluid import core

             graph = IrGraph(core.Graph(program.desc), for_test=False)
-            exe = fluid.Executor(fluid.CPUPlace())
+            place = fluid.CPUPlace()
             transform_pass = QuantizationTransformPass(fluid.global_scope(),
-            exe)
+            place)
             transform_pass.apply(graph)
         """
         self._scope = scope
-        self._program_exe = program_exe
+        self._place = place
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
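Taken together, the hunks above swap the executor-based API for a place-based one: callers now hand QuantizationTransformPass a device place, and the pass builds its own Executor on demand (see the @@ -175 hunk below). A minimal caller-side sketch of the new signature; the scope and quantize-type values here are illustrative, not taken from this commit:

    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    place = fluid.CPUPlace()  # or fluid.CUDAPlace(0) when running on GPU
    transform_pass = QuantizationTransformPass(
        scope=fluid.global_scope(),
        place=place,  # replaces the removed program_exe=exe argument
        activation_quantize_type='range_abs_max')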
@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]

         def _transform_forward(graph, op):
             for var_node in op.inputs:
@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
         if not self._is_test:
             self._create_global_step(graph)
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
         for op in ops:
@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
         if len(self._need_initialized) > 0:
             assert self._scope is not None, \
                 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._program_exe is not None, \
-                'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
+            assert self._place is not None, \
+                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
             init_program = Program()
             for var_desc, initializer in six.iteritems(self._need_initialized):
                 var = init_program.global_block().create_var(
@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
                     lod_level=var_desc.lod_level(),
                     persistable=var_desc.persistable())
                 initializer(var, init_program.global_block())
-            self._program_exe.run(program=init_program, scope=self._scope)
+            exe = Executor(self._place)
+            exe.run(program=init_program, scope=self._scope)

         return graph
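This hunk is the reason Executor was added to the imports in the first hunk: instead of running the initialization program on a caller-supplied executor, the pass now builds a short-lived Executor from the stored place whenever new 'range_abs_max' parameters must be initialized. A condensed sketch of the new flow, simplified from the hunk above (the elided lines create the variables in init_program):

    init_program = Program()
    # ... create_var() / initializer() calls populate init_program (elided) ...
    exe = Executor(self._place)                       # built on demand from `place`
    exe.run(program=init_program, scope=self._scope)  # materializes params in scope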
@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
         if self._weight_quantize_type == 'range_abs_max' or \
                 self._activation_quantize_type == 'range_abs_max':
             counter_name = cpt.to_text('@STEP_COUNTER@')
-            for node in graph.all_vars():
+            for node in graph.all_var_nodes():
                 if node.name() == counter_name:
                     self._global_step = node
         if self._global_step is None:
-            global_step_in = graph.create_param_node(
+            global_step_in = graph.create_persistable_node(
                 name=counter_name,
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
@@ -262,7 +265,7 @@ class QuantizationTransformPass(object):
             shape=var_node.var().shape(),
             var_dtype=var_node.var().dtype())
-        scale_in_node = graph.create_param_node(
+        scale_in_node = graph.create_persistable_node(
             name=self._quantized_scale_name(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
@@ -275,7 +278,7 @@ class QuantizationTransformPass(object):
         if not self._is_test:
             # The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
-            scales_node = graph.create_param_node(
+            scales_node = graph.create_persistable_node(
                 name=unique_name.generate('scales'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
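The three hunks above are pure renames of graph.create_param_node to graph.create_persistable_node; the framework.py changes below explain why (an IrGraph cannot tell parameters apart from other persistable variables, and the factory now returns an IrVarNode). A sketch of one call under the new name, assuming core is paddle.fluid.core; the variable name and dtype are hypothetical:

    scale_in_node = graph.create_persistable_node(   # formerly create_param_node
        name='input_scale',                          # hypothetical variable name
        var_type=core.VarDesc.VarType.LOD_TENSOR,
        shape=[1],
        var_dtype=core.VarDesc.VarType.FP32)         # dtype assumed for the sketch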
@@ -400,8 +403,8 @@ class QuantizationFreezePass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_quant_op_names:

@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
                         self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)

-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_dequant_op_names:
                 self._remove_fake_quant_and_dequant_op(graph, op_node)

-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:

@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
     def _insert_post_dequant_op(self, graph, op_node):
         max_range = None
         scale_var_node = None
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
             if name in self._op_input_rename_map:

@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
                     original_var_name)
                 max_range = param_range * act_range / scale_v
             else:
-                assert isinstance(scale_v, core.Node)
+                assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]

         if len(op_node.outputs) != 1:
@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)

-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)

     def _original_var_name(self, var_name):
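The rewritten set arithmetic is needed because all_var_nodes() now returns freshly constructed IrVarNode wrappers. Two wrappers around the same core.Node are distinct Python objects (IrNode defines no custom __eq__/__hash__ in this commit), so a plain set difference over wrappers would remove nothing. The diff therefore compares the underlying .node handles; the same pattern applies anywhere wrapped node sets are intersected:

    # Wrapper identity != underlying node identity, so unwrap before comparing.
    used_raw = {n.node for n in all_used_vars}       # raw core.Node handles
    unused = {n for n in graph.all_var_nodes()
              if n.node not in used_raw}             # keep wrappers, test raw handles
    graph.safe_remove_nodes(unused)                  # accepts wrapped nodes (see framework.py)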
@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         input_map = {}
         for op_node in ops:
             op_name = op_node.name()

@@ -605,7 +613,7 @@ class ConvertToInt8Pass(object):
     def _convert_to_int8(self, graph, var_node):
         int8_var_node_name = var_node.name() + ".int8"
-        int8_var_node = graph.create_param_node(
+        int8_var_node = graph.create_persistable_node(
             name=cpt.to_text(int8_var_node_name),
             var_type=var_node.var().type(),
             shape=var_node.var().shape(),

@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)

-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)

@@ -655,7 +668,7 @@ class TransformForMobilePass(object):
         Args:
             graph(IrGraph): the graph will be transformed.
         """
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
             if name in self._fake_quant_op_names:
python/paddle/fluid/contrib/slim/tests/test_graph.py

@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase):
         opt.minimize(loss)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('conv2d') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'residual', marked_nodes)
         self.assertFalse(graph.has_circle())
         self.assertEqual(graph.graph_num(), 1)
         nodes = graph.topology_sort()
-        self.assertEqual(len(nodes), len(graph.all_ops()))
+        self.assertEqual(len(nodes), len(graph.all_op_nodes()))
         nodes_map = graph.build_adjacency_list()
-        self.assertEqual(len(nodes_map), len(graph.all_ops()))
+        self.assertEqual(len(nodes_map), len(graph.all_op_nodes()))
         nodes_num = len(graph.all_nodes())
         graph.safe_remove_nodes(marked_nodes)
         self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
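The test updates are mechanical renames, but they document the query surface after this commit: all_op_nodes() and all_var_nodes() (formerly all_ops()/all_vars()) return wrapped nodes, while counting-style assertions carry over unchanged. A sketch in the style of the test above, assuming main is an existing Program:

    graph = IrGraph(core.Graph(main.desc), for_test=False)
    conv_ops = {op for op in graph.all_op_nodes()     # renamed from all_ops()
                if op.name().find('conv2d') > -1}
    assert len(graph.topology_sort()) == len(graph.all_op_nodes())
    graph.safe_remove_nodes(conv_ops)                 # takes the wrapped nodes directly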
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py

@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
         loss = linear_fc(3)
         opt = fluid.optimizer.Adam(learning_rate=0.001)
         opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
-            program_exe=exe,
+            place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)

@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
-        for op in val_graph.all_ops():
+        for op in val_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 val_marked_nodes.add(op)
         val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)

@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
         loss = residual_block(2)
         opt = fluid.optimizer.Adam(learning_rate=0.001)
         opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
-            program_exe=exe,
+            place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)

@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
-        for op in val_graph.all_ops():
+        for op in val_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 val_marked_nodes.add(op)
         val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
         with fluid.scope_guard(scope):
             exe.run(startup)
         transform_pass = QuantizationTransformPass(
-            scope=scope, program_exe=exe, activation_quantize_type=quant_type)
+            scope=scope, place=place, activation_quantize_type=quant_type)
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
         dev_name = '_gpu_' if use_cuda else '_cpu_'
         marked_nodes = set()
-        for op in main_graph.all_ops():
+        for op in main_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
         marked_nodes = set()
-        for op in test_graph.all_ops():
+        for op in test_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)

@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
         iters = 5
         batch_size = 8
-        #train_exe = fluid.ParallelExecutor(
-        #    main_program=quantized_main_program,
-        #    use_cuda=bool(use_cuda),
-        #    loss_name=loss.name,
-        #    scope=scope)
         train_reader = paddle.batch(
             paddle.reader.shuffle(
                 paddle.dataset.mnist.train(), buf_size=500),

@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                 loss_v = exe.run(program=quantized_main_program,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
-                #loss_v = train_exe.run(feed=feeder.feed(data),
-                #                       fetch_list=[loss.name])
-                #print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
+                print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
             test_data = next(test_reader())
             with fluid.program_guard(quantized_test_program):

@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
             freeze_pass = QuantizationFreezePass(scope=scope, place=place)
             freeze_pass.apply(test_graph)
             marked_nodes = set()
-            for op in test_graph.all_ops():
+            for op in test_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
                     marked_nodes.add(op)
             test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
...
@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed
=
feeder
.
feed
(
test_data
),
feed
=
feeder
.
feed
(
test_data
),
fetch_list
=
[
loss
])
fetch_list
=
[
loss
])
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
#
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
quant_type
,
test_loss1
))
#
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
quant_type
,
test_loss2
))
w_freeze
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0'
).
get_tensor
())
w_freeze
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0'
).
get_tensor
())
# Maybe failed, this is due to the calculation precision
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
#
np.sum(w_freeze)))
np
.
sum
(
w_freeze
)))
#
print('{}: {}'.format('w_quant' + dev_name + quant_type,
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
quant_type
,
#
np.sum(w_quant)))
np
.
sum
(
w_quant
)))
# Convert parameter to 8-bit.
# Convert parameter to 8-bit.
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
scope
,
place
=
place
)
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
scope
,
place
=
place
)
convert_int8_pass
.
apply
(
test_graph
)
convert_int8_pass
.
apply
(
test_graph
)
marked_nodes
=
set
()
marked_nodes
=
set
()
for
op
in
test_graph
.
all_ops
():
for
op
in
test_graph
.
all_op
_node
s
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
quant_type
,
marked_nodes
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
quant_type
,
marked_nodes
)
...
@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
...
@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0.int8'
).
get_tensor
())
w_8bit
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0.int8'
).
get_tensor
())
self
.
assertEqual
(
w_8bit
.
dtype
,
np
.
int8
)
self
.
assertEqual
(
w_8bit
.
dtype
,
np
.
int8
)
self
.
assertEqual
(
np
.
sum
(
w_8bit
),
np
.
sum
(
w_freeze
))
self
.
assertEqual
(
np
.
sum
(
w_8bit
),
np
.
sum
(
w_freeze
))
#
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
quant_type
,
np
.
sum
(
w_8bit
)))
#
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
#
np.sum(w_freeze)))
np
.
sum
(
w_freeze
)))
mobile_pass
=
TransformForMobilePass
()
mobile_pass
=
TransformForMobilePass
()
mobile_pass
.
apply
(
test_graph
)
mobile_pass
.
apply
(
test_graph
)
marked_nodes
=
set
()
marked_nodes
=
set
()
for
op
in
test_graph
.
all_ops
():
for
op
in
test_graph
.
all_op
_node
s
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
quant_type
,
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
quant_type
,
...
...
python/paddle/fluid/framework.py

@@ -1538,10 +1538,297 @@ class Block(object):
         return ret_var


+class IrNode(object):
+    """
+    Python IrNode. Beneath it is a core.Node, which is used for Ir Pass.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node), \
+            'node must be the instance of core.Node.'
+        self.node = node
+
+    def name(self):
+        """
+        Return the node name.
+
+        Returns:
+            str: node name.
+        """
+        return self.node.name()
+
+    def node_type(self):
+        """
+        Return the node type.
+
+        Returns:
+            core.Node.Type: node type(core.Node.Type.Operation or core.Node.Type.Variable).
+        """
+        return self.node.node_type()
+
+    def var(self):
+        """
+        Return the node variable description.
+
+        Returns:
+            core.VarDesc: node variable description.
+        """
+        return self.node.var()
+
+    def op(self):
+        """
+        Return the node operator description.
+
+        Returns:
+            core.OpDesc: node operator description.
+        """
+        return self.node.op()
+
+    def id(self):
+        """
+        Return the node id.
+
+        Returns:
+            int: node id.
+        """
+        return self.node.id()
+
+    def is_op(self):
+        """
+        If the node is an operator, then return true.
+
+        Returns:
+            bool: indicate whether the node is an operator.
+        """
+        return self.node.is_op()
+
+    def is_var(self):
+        """
+        If the node is a variable, then return true.
+
+        Returns:
+            bool: indicate whether the node is a variable.
+        """
+        return self.node.is_var()
+
+    def is_ctrl_var(self):
+        """
+        If the node is a control dependence variable, then return true.
+
+        Returns:
+            bool: indicate whether the node is a control dependence variable.
+        """
+        return self.node.is_ctrl_var()
+
+    def clear_inputs(self):
+        """
+        Clear the node inputs. After executing the `clear_inputs` function,
+        the node inputs will be empty.
+        """
+        self.node.clear_inputs()
+
+    def inputs_remove_by_id(self, node_id):
+        """
+        Remove a node from inputs by the given node id.
+
+        Args:
+            node_id(int): the given node id.
+        """
+        self.node.inputs_remove(node_id)
+
+    def inputs_remove(self, ir_node):
+        """
+        Remove a node from inputs.
+
+        Args:
+            ir_node(IrNode): the node being removed.
+        """
+        self.node.inputs_remove(ir_node.node)
+
+    def inputs_append(self, ir_node):
+        """
+        Append a node in inputs.
+
+        Args:
+            ir_node(IrNode): the node being appended.
+        """
+        self.node.inputs_append(ir_node.node)
+
+    def clear_outputs(self):
+        """
+        Clear the node outputs. After executing the `clear_outputs` function,
+        the node outputs will be empty.
+        """
+        self.node.clear_outputs()
+
+    def outputs_remove_by_id(self, node_id):
+        """
+        Remove a node from outputs by the given node id.
+
+        Args:
+            node_id(int): the given node id.
+        """
+        self.node.outputs_remove(node_id)
+
+    def outputs_remove(self, ir_node):
+        """
+        Remove a node from outputs.
+
+        Args:
+            ir_node(IrNode): the node being removed.
+        """
+        self.node.outputs_remove(ir_node.node)
+
+    def outputs_append(self, ir_node):
+        """
+        Append a node in outputs.
+
+        Args:
+            ir_node(IrNode): the node being appended.
+        """
+        self.node.outputs_append(ir_node.node)
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrNode): node inputs wrapped by IrNode.
+        """
+        return [IrNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrNode): node outputs wrapped by IrNode.
+        """
+        return [IrNode(n) for n in self.node.outputs]
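IrNode is a thin pass-through wrapper: every method defers to the wrapped core.Node, and the inputs/outputs properties re-wrap the neighbors so a traversal never leaves wrapper space. A small usage sketch, assuming an existing IrGraph g:

    node = next(iter(g.all_nodes()))     # all_nodes() now yields IrNode (see below)
    print(node.name(), node.id(), node.is_op())
    for producer in node.inputs:         # neighbors come back wrapped as IrNode too
        print(producer.name())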
+
+
+class IrVarNode(IrNode):
+    """
+    Python IrVarNode. Beneath it is a core.Node, it inherits from IrNode.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrVarNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node) and node.is_var(), \
+            'node must be the instance of core.Node and it must be a variable node.'
+        super(IrVarNode, self).__init__(node)
+        self.node = node
+
+    def set_shape(self, shape):
+        """
+        Set the node variable shape.
+
+        Args:
+            shape(list): shape to be set.
+        """
+        assert self.node.var() is not None, \
+            "The node variable description cannot be None."
+        self.node.var().set_shape(shape)
+
+    def persistable(self):
+        """
+        If the variable node is a persistable variable, then return true.
+
+        Returns:
+            bool: indicate whether the variable is persistable.
+        """
+        assert self.node.var() is not None, \
+            "The node variable description cannot be None."
+        return self.node.var().persistable()
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrOpNode): node inputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrOpNode): node outputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.outputs]
+
+
+class IrOpNode(IrNode):
+    """
+    Python IrOpNode. Beneath it is a core.Node, it inherits from IrNode.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrOpNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node) and node.is_op(), \
+            'node must be the instance of core.Node and it must be a operator node.'
+        super(IrOpNode, self).__init__(node)
+        self.node = node
+
+    def rename_input(self, old_input_name, new_input_name):
+        """
+        Rename the input of this node.
+
+        Args:
+            old_input_name(str): the old input name.
+            new_input_name(str): the new input name.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        self.node.op()._rename_input(old_input_name, new_input_name)
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrVarNode): node inputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrVarNode): node outputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.outputs]
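The two subclasses add type-specific helpers (set_shape/persistable on variables, rename_input on operators) and override inputs/outputs with cross-typed wrapping: a variable's neighbors are operators and vice versa, so alternating var/op walks stay well-typed. A sketch, assuming op is an IrOpNode:

    for var in op.inputs:                # IrVarNode instances
        if var.persistable():            # helper asserts var.var() is not None first
            print(var.name(), var.var().shape())
        for consumer in var.outputs:     # a variable's neighbors are IrOpNode again
            print(consumer.name())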


 class IrGraph(object):
     """
     Python IrGraph. Beneath it is a core.Graph, which is used for
-    create a c++ Ir Pass Graph. An IrGraph is just a graph view of
+    creating a c++ Ir Pass Graph. An IrGraph is just a graph view of
     a Program. In an IrGraph, both Variables and Operators are graph
     nodes.
     """
@@ -1569,15 +1856,15 @@ class IrGraph(object):
         """
         Return all nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes()}
+        return {IrNode(node) for node in self.graph.nodes()}

-    def all_vars(self):
+    def all_var_nodes(self):
         """
         Return all variable nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_var()}
+        return {IrVarNode(node) for node in self.graph.nodes() if node.is_var()}

-    def all_persistable_vars(self):
+    def all_persistable_nodes(self):
         """
         Return all persistable variable nodes included in the graph as a set.
         """

@@ -1586,13 +1873,13 @@ class IrGraph(object):
             if node.is_var() and node.var() is not None and node.var(
             ).persistable():
                 persistable_nodes.add(node)
-        return persistable_nodes
+        return {IrVarNode(p) for p in persistable_nodes}

-    def all_ops(self):
+    def all_op_nodes(self):
         """
         Return all operator nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_op()}
+        return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

     def var_node(self, name):
         """
@@ -1606,14 +1893,14 @@ class IrGraph(object):
             doesn't have a variable with the giving name.

         Returns:
-            core.Node: the variable node with the giving name.
+            IrVarNode: the variable node with the giving name.
         """
         if not isinstance(name, six.string_types):
             raise TypeError(
                 "var require string as parameter, but get %s instead." %
                 (type(name)))
         target_var_node = None
-        var_nodes = self.all_vars()
+        var_nodes = self.all_var_nodes()
         for var_node in var_nodes:
             if var_node.name() == name:
                 target_var_node = var_node
@@ -1621,7 +1908,7 @@ class IrGraph(object):
             raise ValueError("var_node %s not in this graph" % name)
         return target_var_node

-    def create_param_node(self, name, var_type, shape, var_dtype):
+    def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
         it can not distinguish between persistable variables and parameters.

@@ -1633,14 +1920,14 @@ class IrGraph(object):
             var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.

         Returns:
-            core.Node: the created persistable variable node.
+            IrVarNode: the created persistable variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
         var_desc.set_persistable(True)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))
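All the node factories now return wrappers instead of raw core.Node objects, so callers can use the wrapper helpers on the result immediately. A hypothetical snippet chaining the IrVarNode API onto a freshly created node:

    counter = graph.create_persistable_node(   # returns IrVarNode, not core.Node
        name='@STEP_COUNTER@',
        var_type=core.VarDesc.VarType.LOD_TENSOR,
        shape=[1],
        var_dtype=core.VarDesc.VarType.INT64)  # dtype assumed for the sketch
    counter.set_shape([1])                     # IrVarNode helper, guards var() inside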
     def create_var_node(self, name, var_type, shape, var_dtype):
         """

@@ -1654,14 +1941,14 @@ class IrGraph(object):
             var_dtype(core.VarDesc.VarType): the data type of the variable node.

         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))

     def create_var_node_from_desc(self, var_desc):
         """

@@ -1672,9 +1959,9 @@ class IrGraph(object):
             var_desc(core.VarDesc): the giving variable description.

         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))

     def create_op_node(self, op_type, attrs, inputs, outputs):
         """

@@ -1687,7 +1974,7 @@ class IrGraph(object):
             outputs(dict): the outpus of the operator node.

         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
         op_desc = core.OpDesc()
         op_desc.set_type(op_type)

@@ -1703,7 +1990,7 @@ class IrGraph(object):
                 var_nodes = [var_nodes]
             op_desc.set_output(output_name,
                                [var_node.name() for var_node in var_nodes])
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))

     def create_op_node_from_desc(self, op_desc):
         """
@@ -1713,37 +2000,37 @@ class IrGraph(object):
             op_desc(core.VarDesc): the giving operator description.

         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))

     def update_input_link(self, old_input_node, new_input_node, op_node):
         """
         Update the input's link of a operator node.

         Args:
-            old_input_node(core.Node): the old input node of the giving op_node.
-            new_input_node(core.Node): the new input node of the giving op_node.
-            op_node(core.Node): the operator node that is needed to update input's link.
+            old_input_node(IrNode): the old input node of the giving op_node.
+            new_input_node(IrNode): the new input node of the giving op_node.
+            op_node(IrOpNode): the operator node that is needed to update input's link.
         """
-        assert old_input_node in self.graph.nodes() and new_input_node in \
-            self.graph.nodes() and op_node in self.graph.nodes(), \
+        assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
+            self.graph.nodes() and op_node.node in self.graph.nodes(), \
             'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
         old_input_node.outputs_remove(op_node)
         op_node.inputs_remove(old_input_node)
         new_input_node.outputs_append(op_node)
         op_node.inputs_append(new_input_node)
-        op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
+        op_node.rename_input(old_input_node.name(), new_input_node.name())

     def link_to(self, node_in, node_out):
         """
         Connect two nodes.

         Args:
-            node_in(core.Node): the input node.
-            node_out(core.Node): the output node.
+            node_in(IrNode): the input node.
+            node_out(IrNode): the output node.
         """
-        assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
+        assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
            'The two arguments(node_in&node_out) must be in the graph nodes.'
         node_in.outputs_append(node_out)
         node_out.inputs_append(node_in)
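update_input_link and link_to now take wrappers, so the membership asserts unwrap with .node before probing self.graph.nodes(), and the final rename goes through the new IrOpNode.rename_input helper instead of reaching into op()._rename_input directly. Rewiring an operator input might then look like this sketch (node names are hypothetical):

    old_in = graph.var_node('conv2d_0.w_0')              # IrVarNode
    new_in = graph.var_node('conv2d_0.w_0.quantized')    # IrVarNode
    graph.update_input_link(old_in, new_in, conv_op)     # conv_op: IrOpNode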
@@ -1761,7 +2048,8 @@ class IrGraph(object):
             remove_nodes = set(remove_nodes)
         else:
             remove_nodes = {remove_nodes}
-        core.graph_safe_remove_nodes(self.graph, remove_nodes)
+        original_nodes = {n.node for n in remove_nodes}
+        core.graph_safe_remove_nodes(self.graph, original_nodes)

     def has_circle(self):
         """
@@ -1788,18 +2076,23 @@ class IrGraph(object):
         Notes: the `graph` cannot contain a circle.

         Returns:
-            set(core.Node): nodes in topology order.
+            set(IrNode): nodes in topology order.
         """
-        return core.topology_sort(self.graph)
+        ordered_nodes = core.topology_sort(self.graph)
+        return {IrNode(n) for n in ordered_nodes}

     def build_adjacency_list(self):
         """
         Build an adjacency list of operations for the `graph`.

         Returns:
-            dict{core.Node: set(core.Node)}: the adjacency list.
+            dict{IrNode: set(IrNode)}: the adjacency list.
         """
-        return core.build_adjacency_list(self.graph)
+        adj_list = core.build_adjacency_list(self.graph)
+        wrapped_adj_list = dict()
+        for k, v in six.iteritems(adj_list):
+            wrapped_adj_list[IrNode(k)] = {IrNode(n) for n in v}
+        return wrapped_adj_list

     def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
         """
@@ -1809,7 +2102,7 @@ class IrGraph(object):
         Args:
             save_path(str): the save path of drawn graph.
             name(str): the name of drawn graph.
-            marked_nodes(set(core.Node)): nodes that are needed to be marked.
+            marked_nodes(set(IrNode)): nodes that are needed to be marked.
                 Default value is None.
             remove_ctr_var(bool): If it is set True, all control variable nodes
                 in the graph will be removed. Default value is True.
@@ -1824,20 +2117,22 @@ class IrGraph(object):
             print('The {} is saved as the dot filetype.'.format(
                 dot_file_path))

-        remove_ctr_vars = set()
         if remove_ctr_var:
-            for node in self.graph.nodes():
+            remove_ctr_vars = set()
+            for node in self.all_var_nodes():
                 if node.is_ctrl_var():
                     remove_ctr_vars.add(node)
             self.safe_remove_nodes(remove_ctr_vars)
-        ops_num = 0
-        for node in self.graph.nodes():
-            if node.is_op():
-                ops_num += 1
-        print('Total ops num = {}.'.format(ops_num))
+        print('Total ops num = {}.'.format(len(self.all_op_nodes())))
+
         if marked_nodes is not None:
             if not isinstance(marked_nodes, set):
-                marked_nodes = set(marked_nodes)
+                if isinstance(marked_nodes, Iterable):
+                    marked_nodes = set(marked_nodes)
+                else:
+                    marked_nodes = {marked_nodes}
+            marked_nodes = {n.node for n in marked_nodes}
+            remove_ctr_vars = {n.node for n in remove_ctr_vars}
             marked_nodes = marked_nodes - remove_ctr_vars
             if self.graph.has('__graphviz__marked_node__'):
                 self.graph.erase('__graphviz__marked_node__')