Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2c4cc68f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2c4cc68f
编写于
6月 25, 2021
作者:
王
王明冬
提交者:
GitHub
6月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add compat precondition for repeated_fc_relu_fuse_pass,test=develop. (#33742)
上级
98d25314
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
42 addition
and
3 deletion
+42
-3
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
+7
-0
paddle/fluid/framework/ir/op_compat_sensible_pass.h
paddle/fluid/framework/ir/op_compat_sensible_pass.h
+2
-0
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
+28
-2
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
+5
-1
未找到文件。
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
浏览文件 @
2c4cc68f
...
...
@@ -23,6 +23,13 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
AttrCompat
&
AttrCompat
::
IsStringEQ
(
const
std
::
string
&
value
)
{
conditions_
.
emplace_back
([
value
](
const
Attribute
&
attr
)
->
bool
{
return
value
==
BOOST_GET_CONST
(
std
::
string
,
attr
);
});
return
*
this
;
}
AttrCompat
&
AttrCompat
::
IsStringIn
(
const
std
::
set
<
std
::
string
>&
candidates
)
{
conditions_
.
emplace_back
([
candidates
](
const
Attribute
&
attr
)
->
bool
{
std
::
string
value
=
BOOST_GET_CONST
(
std
::
string
,
attr
);
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.h
浏览文件 @
2c4cc68f
...
...
@@ -37,6 +37,8 @@ class AttrCompat {
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat
&
IsStringEQ
(
const
std
::
string
&
value
);
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat
&
IsStringIn
(
const
std
::
set
<
std
::
string
>&
candidates
);
//! Assert the attribute is a string and match a custom judging function.
AttrCompat
&
IsStringMatch
(
...
...
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
浏览文件 @
2c4cc68f
...
...
@@ -31,6 +31,27 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
RepeatedFCReluFusePass
::
RepeatedFCReluFusePass
()
{
AddOpCompat
(
OpCompat
(
"fc"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"in_num_col_dims"
)
.
IsNumEQ
(
1
)
.
End
()
.
AddAttr
(
"activation_type"
)
.
IsStringEQ
(
"relu"
)
.
End
();
}
static
bool
IsInputOfFC
(
Node
*
n
)
{
if
(
n
&&
n
->
IsVar
()
&&
VarLinksToOp
(
n
,
"fc"
))
{
return
true
;
...
...
@@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
}
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_fc
)
{
int
RepeatedFCReluFusePass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_fc
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildRepeatedFCReluPattern
(
pattern
,
name_scope
,
num_fc
);
...
...
@@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"repeated_fc_relu_fuse_pass failed in op compat."
;
return
;
}
LOG
(
INFO
)
<<
"handle Repeated FC Act fuse"
;
std
::
vector
<
Node
*>
weights_vars
(
num_fc
);
std
::
vector
<
Node
*>
bias_vars
(
num_fc
);
...
...
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
浏览文件 @
2c4cc68f
...
...
@@ -31,12 +31,16 @@ class Graph;
class
RepeatedFCReluFusePass
:
public
FusePassBase
{
public:
virtual
~
RepeatedFCReluFusePass
()
{}
RepeatedFCReluFusePass
();
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"repeated_fc_relu_fuse"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_fc
)
const
;
};
}
// namespace ir
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录