Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ca6fdc6e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ca6fdc6e
编写于
1月 13, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine and fix test
test=develop
上级
a89296ac
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
31 deletion
+10
-31
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
+7
-30
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+3
-1
未找到文件。
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
浏览文件 @
ca6fdc6e
...
@@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
...
@@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
return
false
;
return
false
;
}
}
auto
*
relu_op
=
x
->
inputs
[
0
];
auto
*
relu_op
=
x
->
inputs
[
0
];
// std::cout << "xxxx" << std::endl;
bool
before_is_fc
=
relu_op
->
IsOp
()
&&
relu_op
->
inputs
.
size
()
==
1
&&
bool
before_is_fc
=
relu_op
->
IsOp
()
&&
relu_op
->
inputs
.
size
()
==
1
&&
relu_op
->
inputs
[
0
]
->
IsVar
()
&&
relu_op
->
inputs
[
0
]
->
IsVar
()
&&
VarLinksFromOp
(
relu_op
->
inputs
[
0
],
"fc"
)
&&
VarLinksFromOp
(
relu_op
->
inputs
[
0
],
"fc"
)
&&
...
@@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
...
@@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
}
}
auto
*
fc_op
=
relu_op
->
inputs
[
0
]
->
inputs
[
0
];
auto
*
fc_op
=
relu_op
->
inputs
[
0
]
->
inputs
[
0
];
bool
is_fc
=
fc_op
->
IsOp
()
&&
fc_op
->
inputs
.
size
()
==
3
;
bool
is_fc
=
fc_op
->
IsOp
()
&&
fc_op
->
inputs
.
size
()
==
3
;
// std::cout << "*****" << fc_op->inputs.size() << std::endl;
if
(
!
is_fc
)
{
if
(
!
is_fc
)
{
return
false
;
return
false
;
}
}
for
(
size_t
kkk
=
0
;
kkk
<
3
;
++
kkk
)
{
for
(
auto
*
fc_i
:
fc_op
->
inputs
)
{
// std::cout << "++++++" << kkk << std::endl;
if
(
!
fc_i
->
inputs
.
empty
())
{
if
(
!
fc_op
->
inputs
[
kkk
]
->
inputs
.
empty
())
{
if
(
at_top
)
{
if
(
at_top
)
{
return
true
;
return
true
;
}
else
{
}
else
{
bool
res
=
VarLinksFromOp
(
fc_op
->
inputs
[
kkk
],
"relu"
);
return
VarLinksFromOp
(
fc_i
,
"relu"
);
// std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk <<
// ":"
// << res << std::endl;
return
res
;
}
}
}
}
}
}
// for (auto* fc_i : fc_op->inputs) {
// if (!fc_i->inputs.empty()) {
// std::cout << "++++++" << fc_op->inputs.size()<<std::endl;
// return VarLinksFromOp(fc_i, "relu");
// }
// }
return
false
;
return
false
;
};
};
...
@@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
...
@@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
Node
*
x
,
int
repeated_times
,
Node
*
x
,
int
repeated_times
,
const
std
::
string
&
act_type
=
"relu"
)
->
bool
{
const
std
::
string
&
act_type
=
"relu"
)
->
bool
{
for
(
int
i
=
0
;
i
<
repeated_times
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeated_times
;
++
i
)
{
// std::cout << "----" << i << std::endl;
if
(
!
var_before_is_fc_act
(
x
,
act_type
,
i
==
repeated_times
-
1
))
{
if
(
!
var_before_is_fc_act
(
x
,
act_type
,
i
==
repeated_times
-
1
))
{
return
false
;
return
false
;
}
}
...
@@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
...
@@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
x
,
std
::
max
(
1
,
num_fc
-
i
-
1
),
"relu"
);
x
,
std
::
max
(
1
,
num_fc
-
i
-
1
),
"relu"
);
}
}
}
else
{
}
else
{
bool
part1
=
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
x
->
inputs
.
size
()
>
0
&&
x
->
inputs
.
size
()
>
0
;
var_before_is_fc_act_repeated_n_times
(
x
,
i
,
"relu"
);
if
(
x
->
Name
()
==
"fc_0.tmp_1"
&&
x
->
IsVar
()
&&
part1
)
{
// std::cout << "testes" << std::endl;
}
bool
part2
=
var_before_is_fc_act_repeated_n_times
(
x
,
i
,
"relu"
);
if
(
x
->
Name
()
==
"fc_0.tmp_1"
)
{
// std::cout << "========" << part1 << "," << part2 << std::endl;
}
return
part1
&&
part2
;
}
}
},
},
name_scope
+
"/fc_in_"
+
std
::
to_string
(
i
));
name_scope
+
"/fc_in_"
+
std
::
to_string
(
i
));
...
@@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
...
@@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
int
fusion_count
=
0
;
int
fusion_count
=
0
;
for
(
int
i
=
MAX_NUM_FC
;
i
>
1
;
--
i
)
{
for
(
int
i
=
MAX_NUM_FC
;
i
>
1
;
--
i
)
{
fusion_count
+=
fusion_count
+=
BuildFusion
(
graph
.
get
(),
name_scope_
+
"/"
+
std
::
to_string
(
3
),
3
);
BuildFusion
(
graph
.
get
(),
name_scope_
+
"/"
+
std
::
to_string
(
i
),
i
);
}
}
AddStatis
(
fusion_count
);
AddStatis
(
fusion_count
);
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
浏览文件 @
ca6fdc6e
...
@@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) {
...
@@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) {
ASSERT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
10
);
ASSERT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
10
);
ASSERT_TRUE
(
fuse_statis
.
count
(
"seqpool_concat_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"seqpool_concat_fuse"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"seqpool_concat_fuse"
),
2
);
EXPECT_EQ
(
fuse_statis
.
at
(
"seqpool_concat_fuse"
),
2
);
ASSERT_TRUE
(
fuse_statis
.
count
(
"repeated_fc_relu"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"repeated_fc_relu"
),
2
);
LOG
(
INFO
)
<<
"num_ops: "
<<
num_ops
;
LOG
(
INFO
)
<<
"num_ops: "
<<
num_ops
;
EXPECT_EQ
(
num_ops
,
1
9
5
);
EXPECT_EQ
(
num_ops
,
1
8
5
);
}
}
// Check the fuse status
// Check the fuse status
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录