Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ab9c4b2a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ab9c4b2a
编写于
1月 10, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine seqpool concat pass and remove unused nodes
test=develop
上级
ce909664
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
128 addition
and
27 deletion
+128
-27
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
+21
-4
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
+107
-23
未找到文件。
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
浏览文件 @
ab9c4b2a
...
...
@@ -76,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
std
::
vector
<
PDNode
*>
seqpool_ops_input_var
(
num_inputs
);
std
::
vector
<
PDNode
*>
seqpool_ops_output_var
(
num_inputs
);
std
::
vector
<
PDNode
*>
seqpool_ops_output_unused_var
(
num_inputs
);
std
::
vector
<
PDNode
*>
seqpool_ops
(
num_inputs
);
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
...
...
@@ -88,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
},
name_scope
+
"/sequence_pool_out_"
+
std
::
to_string
(
i
));
seqpool_ops_output_unused_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
x
->
outputs
.
size
()
==
0
&&
is_seqpool_op_with_pootype_of_nth_input_of_concat
(
x
->
inputs
[
0
],
"SUM"
,
i
);
},
name_scope
+
"/sequence_pool_unused_out_"
+
std
::
to_string
(
i
));
seqpool_ops
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
...
...
@@ -97,16 +107,23 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
seqpool_ops_input_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
outputs
.
size
()
>=
1
&&
is_seqpool_op_with_pootype_of_nth_input_of_concat
(
x
->
outputs
[
0
],
"SUM"
,
i
);
bool
basic
=
x
&&
x
->
IsVar
()
&&
x
->
outputs
.
size
()
>=
1
;
bool
next_is_fine
=
false
;
for
(
auto
*
o
:
x
->
outputs
)
{
if
(
is_seqpool_op_with_pootype_of_nth_input_of_concat
(
o
,
"SUM"
,
i
))
{
next_is_fine
=
true
;
break
;
}
}
return
basic
&&
next_is_fine
;
},
name_scope
+
"/sequence_pool_in_"
+
std
::
to_string
(
i
));
// Links
seqpool_ops
[
i
]
->
LinksFrom
({
seqpool_ops_input_var
[
i
]})
.
LinksTo
({
seqpool_ops_output_var
[
i
]});
.
LinksTo
({
seqpool_ops_output_var
[
i
]
,
seqpool_ops_output_unused_var
[
i
]
});
}
concat_op
->
LinksFrom
(
seqpool_ops_output_var
).
LinksTo
({
concat_out_var
});
return
concat_out_var
;
...
...
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
浏览文件 @
ab9c4b2a
...
...
@@ -35,11 +35,35 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op
->
SetInput
(
"X"
,
inputs
);
op
->
SetAttr
(
"axis"
,
1
);
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
else
{
op
->
SetInput
(
"X"
,
inputs
);
op
->
SetOutput
(
"Out"
,
outputs
);
}
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
}
int
CountOpType
(
const
ir
::
Graph
*
graph
,
const
std
::
string
&
op_type
=
"fusion_seqpool_concat"
)
{
int
count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
op_type
)
{
++
count
;
}
}
return
count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
GetNumNodesOfBeforeAfter
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
,
int
*
before
,
int
*
after
,
const
std
::
string
&
pass_type
=
"seqpool_concat_fuse_pass"
)
{
auto
pass
=
PassRegistry
::
Instance
().
Get
(
pass_type
);
*
before
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
*
after
=
graph
->
Nodes
().
size
();
return
graph
;
}
/*
* Before fuse:
* a b c
...
...
@@ -51,15 +75,16 @@ void SetOp(ProgramDesc* prog, const std::string& type,
* concat
* |
* j
* Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr
*
* After fuse:
* a b c
* \ | /
* fusion_seqpool_concat
* |
* j
* unused nodes: d, f, h
*/
ProgramDesc
BuildProgramDesc
(
)
{
TEST
(
SeqPoolConcatFusePass
,
basic
)
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
,
"g"
,
"h"
,
"i"
,
"j"
}))
{
...
...
@@ -76,35 +101,94 @@ ProgramDesc BuildProgramDesc() {
SetOp
(
&
prog
,
"concat"
,
std
::
vector
<
std
::
string
>
({
"e"
,
"g"
,
"i"
}),
std
::
vector
<
std
::
string
>
({
"j"
}));
return
prog
;
}
TEST
(
SeqPoolConcatFusePass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
int
before
,
after
;
graph
=
GetNumNodesOfBeforeAfter
(
std
::
move
(
graph
),
&
before
,
&
after
);
// Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ
(
after
,
before
-
9
);
EXPECT_EQ
(
CountOpType
(
graph
.
get
()),
1
);
}
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"seqpool_concat_fuse_pass"
);
int
pre_nodes
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
/*
* Before fuse:
* a b
* | / \
* op1 op2 op3
* / \ / \ \
* c d e f g
* \ /
* concat
* |
* h
* Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr
*
* After fuse:
* a b
* \ / \
* fusion_seqpool_concat op3
* | |
* h g
*/
TEST
(
SeqPoolConcatFusePass
,
advanced
)
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
,
"g"
,
"h"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
}
int
after_nodes
=
graph
->
Nodes
().
size
();
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
"a"
}),
std
::
vector
<
std
::
string
>
({
"c"
,
"d"
}));
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
"b"
}),
std
::
vector
<
std
::
string
>
({
"e"
,
"f"
}));
SetOp
(
&
prog
,
"op3"
,
std
::
vector
<
std
::
string
>
({
"b"
}),
std
::
vector
<
std
::
string
>
({
"g"
}));
SetOp
(
&
prog
,
"concat"
,
std
::
vector
<
std
::
string
>
({
"d"
,
"f"
}),
std
::
vector
<
std
::
string
>
({
"h"
}));
// Remove 7 Nodes: op1, op2, op3, e, g, i, concat_op
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
int
before
,
after
;
graph
=
GetNumNodesOfBeforeAfter
(
std
::
move
(
graph
),
&
before
,
&
after
);
// Remove 7 Nodes: op1, op2, c, d, e, f concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ
(
pre_nodes
-
6
,
after_nodes
);
EXPECT_EQ
(
after
,
before
-
6
);
EXPECT_EQ
(
CountOpType
(
graph
.
get
()),
1
);
}
// Assert new op in newly generated graph
int
count
=
0
;
ProgramDesc
BuildProgramDesc
(
int
num_inputs_of_concat
)
{
ProgramDesc
prog
;
auto
new_var
=
[
&
](
const
std
::
string
&
name
)
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
};
std
::
vector
<
std
::
string
>
concat_inputs
;
for
(
int
i
=
0
;
i
<
num_inputs_of_concat
;
++
i
)
{
std
::
string
prefix
=
"seqpool_op_"
+
i
;
new_var
(
prefix
+
"in"
);
new_var
(
prefix
+
"out"
);
new_var
(
prefix
+
"out_unused"
);
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
prefix
+
"in"
}),
std
::
vector
<
std
::
string
>
({
prefix
+
"out"
,
prefix
+
"out_unused"
}));
concat_inputs
.
push_back
(
prefix
+
"out"
);
}
SetOp
(
&
prog
,
"concat"
,
concat_inputs
,
std
::
vector
<
std
::
string
>
({
"concat_out"
}));
return
prog
;
}
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fusion_seqpool_concat"
)
{
++
count
;
}
// test more inputs of concat
TEST
(
SeqPoolConcatFusePass
,
more_inputs
)
{
for
(
int
num
:
{
1
,
2
,
10
})
{
ProgramDesc
prog
=
BuildProgramDesc
(
num
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
int
before
,
after
;
graph
=
GetNumNodesOfBeforeAfter
(
std
::
move
(
graph
),
&
before
,
&
after
);
// Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op
// Add Node: fusion_seqpool_concat op
EXPECT_EQ
(
after
,
before
-
num
*
3
);
EXPECT_EQ
(
CountOpType
(
graph
.
get
()),
1
);
}
EXPECT_EQ
(
count
,
1
);
}
}
// namespace ir
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录