Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a0a27bd2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0a27bd2
编写于
1月 09, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add seqpool concat fuse pass tester
test=develop
上级
8e086a85
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
118 addition
and
4 deletion
+118
-4
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
+3
-4
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
+114
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
a0a27bd2
...
...
@@ -69,6 +69,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
cc_test
(
graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass
)
cc_test
(
test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
cc_test
(
test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto
)
cc_test
(
test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass
)
if
(
WITH_MKLDNN
)
cc_test
(
test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass
)
...
...
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
浏览文件 @
a0a27bd2
...
...
@@ -112,8 +112,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
return
concat_out_var
;
}
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
,
int
num_inputs
)
{
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_inputs
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildSeqPoolConcatPattern
(
pattern
,
name_scope
,
num_inputs
);
...
...
@@ -182,8 +181,8 @@ std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
0
;
for
(
int
i
=
MAX_CONCAT_INPUTS
;
i
>
0
;
--
i
)
{
fusion_count
+=
BuildFusion
(
graph
.
get
(),
name_scope_
+
"/"
+
std
::
to_string
(
i
),
param_scope
(
),
i
);
fusion_count
+=
BuildFusion
(
graph
.
get
(),
name_scope_
+
"/"
+
std
::
to_string
(
i
),
i
);
}
AddStatis
(
fusion_count
);
...
...
paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc
0 → 100644
浏览文件 @
a0a27bd2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
if
(
type
==
"sequence_pool"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
std
::
string
pooltype
=
"SUM"
;
op
->
SetAttr
(
"pooltype"
,
pooltype
);
op
->
SetOutput
(
"MaxIndex"
,
{
outputs
[
0
]});
op
->
SetOutput
(
"Out"
,
{
outputs
[
1
]});
}
else
if
(
type
==
"concat"
)
{
op
->
SetInput
(
"X"
,
inputs
);
op
->
SetAttr
(
"axis"
,
1
);
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
}
/*
* Before fuse:
* a b c
* | | |
* op1 op2 op3
* / \ / \ / \
* d e f g h i
* \ | /
* concat
* |
* j
* After fuse:
* a b c
* \ | /
* fusion_seqpool_concat
* |
* j
* unused nodes: d, f, h
*/
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
,
"g"
,
"h"
,
"i"
,
"j"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
}
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
"a"
}),
std
::
vector
<
std
::
string
>
({
"d"
,
"e"
}));
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
"b"
}),
std
::
vector
<
std
::
string
>
({
"f"
,
"g"
}));
SetOp
(
&
prog
,
"sequence_pool"
,
std
::
vector
<
std
::
string
>
({
"c"
}),
std
::
vector
<
std
::
string
>
({
"h"
,
"i"
}));
SetOp
(
&
prog
,
"concat"
,
std
::
vector
<
std
::
string
>
({
"e"
,
"g"
,
"i"
}),
std
::
vector
<
std
::
string
>
({
"j"
}));
return
prog
;
}
TEST
(
SeqPoolConcatFusePass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"seqpool_concat_fuse_pass"
);
int
pre_nodes
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
after_nodes
=
graph
->
Nodes
().
size
();
// Remove 7 Nodes: op1, op2, op3, e, g, i, concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ
(
pre_nodes
-
6
,
after_nodes
);
// Assert new op in newly generated graph
int
count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fusion_seqpool_concat"
)
{
++
count
;
}
}
EXPECT_EQ
(
count
,
1
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
seqpool_concat_fuse_pass
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录