Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
478a4e85
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
478a4e85
编写于
9月 10, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
9月 10, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor ir pattern (#13304)
上级
14242eae
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
316 addition
and
240 deletion
+316
-240
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+11
-22
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
+44
-62
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+61
-91
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+62
-61
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+130
-4
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+6
-0
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+2
-0
未找到文件。
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
478a4e85
...
@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std
::
unordered_set
<
Node
*>
nodes2delete
;
std
::
unordered_set
<
Node
*>
nodes2delete
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
// BuildFCPattern(gpd.mutable_pattern());
auto
*
x
=
gpd
.
mutable_pattern
()
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
"fc_fuse/x"
)
->
NewNode
(
"fc_fuse/x"
)
->
AsInput
()
->
AsInput
()
->
assert_is_op_input
(
"mul"
,
"X"
);
->
assert_is_op_input
(
"mul"
,
"X"
);
patterns
::
FC
(
gpd
.
mutable_pattern
(),
"fc_fuse"
,
x
,
true
/*with bias*/
);
patterns
::
FC
fc_pattern
(
gpd
.
mutable_pattern
(),
"fc_fuse"
);
fc_pattern
(
x
,
true
/*with bias*/
);
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
int
found_fc_count
=
0
;
int
found_fc_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle FC fuse"
;
VLOG
(
4
)
<<
"handle FC fuse"
;
// Currently, there is no FC op available, so I will just simulate the
GET_IR_NODE_FROM_SUBGRAPH
(
w
,
w
,
fc_pattern
);
// scenerio.
GET_IR_NODE_FROM_SUBGRAPH
(
fc_bias
,
bias
,
fc_pattern
);
// FC's fusion is simple, just op fuse, no need to process the
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
Out
,
fc_pattern
);
// parameters.
GET_IR_NODE_FROM_SUBGRAPH
(
mul
,
mul
,
fc_pattern
);
GET_NODE
(
x
);
// x
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add
,
elementwise_add
,
fc_pattern
);
GET_NODE
(
w
);
// Y
GET_IR_NODE_FROM_SUBGRAPH
(
mul_out
,
mul_out
,
fc_pattern
);
GET_NODE
(
fc_bias
);
// bias
GET_NODE
(
fc_out
);
// Out
GET_NODE
(
mul
);
// MUL op
GET_NODE
(
elementwise_add
);
// ELEMENT_ADD op
GET_NODE
(
mul_out
);
// tmp
#undef GET_NODE
// Create an FC Node.
// Create an FC Node.
OpDesc
desc
;
OpDesc
desc
;
std
::
string
fc_x_in
=
x
->
Name
();
std
::
string
fc_x_in
=
subgraph
.
at
(
x
)
->
Name
();
std
::
string
fc_Y_in
=
w
->
Name
();
std
::
string
fc_Y_in
=
w
->
Name
();
std
::
string
fc_bias_in
=
fc_bias
->
Name
();
std
::
string
fc_bias_in
=
fc_bias
->
Name
();
std
::
string
fc_out_out
=
fc_out
->
Name
();
std
::
string
fc_out_out
=
fc_out
->
Name
();
...
@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto
fc_node
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
auto
fc_node
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
GraphSafeRemoveNodes
(
graph
.
get
(),
{
mul
,
elementwise_add
,
mul_out
});
GraphSafeRemoveNodes
(
graph
.
get
(),
{
mul
,
elementwise_add
,
mul_out
});
IR_NODE_LINK_TO
(
x
,
fc_node
);
PADDLE_ENFORCE
(
subgraph
.
count
(
x
));
IR_NODE_LINK_TO
(
subgraph
.
at
(
x
),
fc_node
);
IR_NODE_LINK_TO
(
w
,
fc_node
);
IR_NODE_LINK_TO
(
w
,
fc_node
);
IR_NODE_LINK_TO
(
fc_bias
,
fc_node
);
IR_NODE_LINK_TO
(
fc_bias
,
fc_node
);
IR_NODE_LINK_TO
(
fc_node
,
fc_out
);
IR_NODE_LINK_TO
(
fc_node
,
fc_out
);
...
...
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
浏览文件 @
478a4e85
...
@@ -20,52 +20,43 @@ namespace paddle {
...
@@ -20,52 +20,43 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
void
BuildPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
with_fc_bias
)
{
PDNode
*
x
=
pattern
->
NewNode
(
name_scope
,
"x"
)
->
assert_is_op_input
(
"mul"
)
->
assert_var_not_persistable
();
auto
*
fc_out
=
patterns
::
FC
(
pattern
,
name_scope
,
x
,
with_fc_bias
);
fc_out
->
AsIntermediate
();
// fc_out is a tmp var, will be removed after fuse.
patterns
::
GRU
(
pattern
,
name_scope
,
fc_out
);
VLOG
(
3
)
<<
"fc_gru pattern
\n
"
<<
pattern
->
DotString
();
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
,
bool
with_fc_bias
)
{
Scope
*
scope
,
bool
with_fc_bias
)
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildPattern
(
pattern
,
name_scope
,
with_fc_bias
);
// Create pattern.
patterns
::
FC
fc_pattern
(
pattern
,
name_scope
);
patterns
::
GRU
gru_pattern
(
pattern
,
name_scope
);
PDNode
*
x
=
pattern
->
NewNode
(
patterns
::
UniqueKey
(
"x"
))
->
assert_var_not_persistable
();
auto
*
fc_out
=
fc_pattern
(
x
,
with_fc_bias
);
fc_out
->
AsIntermediate
();
// fc_out is a tmp var, will be removed after fuse.
gru_pattern
(
fc_out
);
// Create New OpDesc
// Create New OpDesc
auto
gru_creater
=
[
&
](
int
gru
,
int
x
,
int
weight_x
,
int
weight_h
,
int
bias
,
auto
gru_creater
=
[
&
](
Node
*
gru
,
Node
*
x
,
Node
*
weight_x
,
Node
*
weight_h
,
int
hidden
,
int
fc_bias
)
{
Node
*
bias
,
Node
*
hidden
,
Node
*
fc_bias
)
{
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE
(
x
);
GET_NODE
(
weight_x
);
GET_NODE
(
weight_h
);
GET_NODE
(
bias
);
GET_NODE
(
hidden
);
GET_NODE
(
gru
);
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_gru"
);
op_desc
.
SetType
(
"fusion_gru"
);
#define NEW_NAME(x) name_scope + "/at." #x ".new"
#define NEW_NAME(x) name_scope + "/at." #x ".new"
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__
##_n
->Name()});
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN
(
X
,
x
);
SET_IN
(
X
,
x
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightH
,
weight_h
);
SET_IN
(
WeightH
,
weight_h
);
if
(
with_fc_bias
)
{
if
(
with_fc_bias
)
{
op_desc
.
SetInput
(
"Bias"
,
{
NEW_NAME
(
bias
)
+
bias
_n
->
Name
()});
op_desc
.
SetInput
(
"Bias"
,
{
NEW_NAME
(
bias
)
+
bias
->
Name
()});
}
else
{
}
else
{
SET_IN
(
Bias
,
bias
);
SET_IN
(
Bias
,
bias
);
}
}
#undef SET_IN
#undef SET_IN
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden
_n
->
Name
()});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden
->
Name
()});
op_desc
.
SetAttr
(
"is_reverse"
,
gru
_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"is_reverse"
,
gru
->
Op
()
->
GetAttr
(
"is_reverse"
));
// TODO(TJ): This should be a option for infer
// TODO(TJ): This should be a option for infer
op_desc
.
SetAttr
(
"use_seq"
,
true
);
op_desc
.
SetAttr
(
"use_seq"
,
true
);
...
@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
PADDLE_ENFORCE
(
scope
);
PADDLE_ENFORCE
(
scope
);
if
(
with_fc_bias
)
{
if
(
with_fc_bias
)
{
// Fusion GRU bias = fcbias + grubias
// Fusion GRU bias = fcbias + grubias
auto
*
fusion_bias_var
=
scope
->
Var
(
NEW_NAME
(
bias
)
+
bias
_n
->
Name
());
auto
*
fusion_bias_var
=
scope
->
Var
(
NEW_NAME
(
bias
)
+
bias
->
Name
());
auto
*
out_bias_tensor
=
auto
*
out_bias_tensor
=
fusion_bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
fusion_bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
fusion_bias_var
);
PADDLE_ENFORCE
(
fusion_bias_var
);
GET_NODE
(
fc_bias
);
auto
*
gru_bias_var
=
scope
->
FindVar
(
bias
->
Name
());
PADDLE_ENFORCE
(
fc_bias_n
);
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias
->
Name
());
auto
*
gru_bias_var
=
scope
->
FindVar
(
bias_n
->
Name
());
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias_n
->
Name
());
PADDLE_ENFORCE
(
gru_bias_var
);
PADDLE_ENFORCE
(
gru_bias_var
);
PADDLE_ENFORCE
(
fc_bias_var
);
PADDLE_ENFORCE
(
fc_bias_var
);
const
auto
&
gru_bias_tenosr
=
gru_bias_var
->
Get
<
framework
::
LoDTensor
>
();
const
auto
&
gru_bias_tenosr
=
gru_bias_var
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef NEW_NAME
#undef NEW_NAME
#undef NEW_IMTERMEDIATE_OUT
#undef NEW_IMTERMEDIATE_OUT
IR_NODE_LINK_TO
(
x
_n
,
op
);
IR_NODE_LINK_TO
(
x
,
op
);
IR_NODE_LINK_TO
(
weight_x
_n
,
op
);
IR_NODE_LINK_TO
(
weight_x
,
op
);
IR_NODE_LINK_TO
(
weight_h
_n
,
op
);
IR_NODE_LINK_TO
(
weight_h
,
op
);
IR_NODE_LINK_TO
(
bias
_n
,
op
);
// actually should link to new bias if have
IR_NODE_LINK_TO
(
bias
,
op
);
// actually should link to new bias if have
IR_NODE_LINK_TO
(
op
,
hidden
_n
);
IR_NODE_LINK_TO
(
op
,
hidden
);
// h0?
// h0?
return
op
;
return
op
;
};
};
...
@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int
fusion_count
{
0
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
#define GET_NODE(name__) \
auto
*
x_n
=
subgraph
.
at
(
x
);
std::string name__##key = name_scope + "/" + #name__; \
GET_IR_NODE_FROM_SUBGRAPH
(
w
,
w
,
fc_pattern
);
auto* name__##n = pattern->RetrieveNode(name__##key); \
GET_IR_NODE_FROM_SUBGRAPH
(
mul
,
mul
,
fc_pattern
);
PADDLE_ENFORCE(name__##n); \
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
Out
,
fc_pattern
);
PADDLE_ENFORCE(subgraph.count(name__##n)); \
GET_IR_NODE_FROM_SUBGRAPH
(
Weight
,
Weight
,
gru_pattern
);
Node* name__##_n = subgraph.at(name__##n); \
GET_IR_NODE_FROM_SUBGRAPH
(
gru
,
gru
,
gru_pattern
);
int name__ __attribute__((unused)) = name__##_n->id();
GET_IR_NODE_FROM_SUBGRAPH
(
Bias
,
Bias
,
gru_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Hidden
,
Hidden
,
gru_pattern
);
GET_NODE
(
x
);
GET_NODE
(
w
);
// fc weight
GET_NODE
(
mul
);
GET_NODE
(
fc_out
);
GET_NODE
(
Weight
);
GET_NODE
(
gru
);
GET_NODE
(
Bias
);
GET_NODE
(
Hidden
);
// nodes need be removed
// nodes need be removed
GET_
NODE
(
BatchGate
);
GET_
IR_NODE_FROM_SUBGRAPH
(
BatchGate
,
BatchGate
,
gru_pattern
);
GET_
NODE
(
BatchResetHiddenPrev
);
GET_
IR_NODE_FROM_SUBGRAPH
(
BatchResetHiddenPrev
,
BatchGate
,
gru_pattern
);
GET_
NODE
(
BatchHidde
n
);
GET_
IR_NODE_FROM_SUBGRAPH
(
BatchHidden
,
BatchGate
,
gru_patter
n
);
if
(
with_fc_bias
)
{
if
(
with_fc_bias
)
{
GET_NODE
(
mul_out
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul_out
,
mul_out
,
fc_pattern
);
GET_NODE
(
fc_bias
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_bias
,
bias
,
fc_pattern
);
GET_NODE
(
elementwise_add
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add
,
elementwise_add
,
fc_pattern
);
gru_creater
(
gru
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
fc_bias
);
gru_creater
(
gru
,
x_n
,
w
,
Weight
,
Bias
,
Hidden
,
fc_bias
);
// Remove unneeded nodes.
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
mul
_n
,
gru_n
,
elementwise_add_n
,
fc_bias_n
,
fc_out_n
,
mul_out_n
,
{
mul
,
gru
,
elementwise_add
,
fc_bias
,
fc_out
,
mul_out
,
BatchGate
,
Batch
Gate_n
,
BatchResetHiddenPrev_n
,
BatchHidden_
n
});
Batch
ResetHiddenPrev
,
BatchHidde
n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
else
{
}
else
{
gru_creater
(
gru
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
-
1
);
gru_creater
(
gru
,
x
_n
,
w
,
Weight
,
Bias
,
Hidden
,
nullptr
);
// Remove unneeded nodes.
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
mul
_n
,
gru_n
,
BatchGate_n
,
BatchResetHiddenPrev_n
,
BatchHidden_
n
});
{
mul
,
gru
,
BatchGate
,
BatchResetHiddenPrev
,
BatchHidde
n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
}
#undef GET_NODE
#undef GET_NODE
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
478a4e85
...
@@ -20,45 +20,29 @@ namespace paddle {
...
@@ -20,45 +20,29 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
std
::
string
GenNodeName
(
const
std
::
string
&
prefix
,
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
,
const
std
::
string
&
name
)
{
bool
with_fc_bias
)
{
return
prefix
+
"/"
+
name
;
GraphPatternDetector
gpd
;
}
auto
*
pattern
=
gpd
.
mutable_pattern
();
static
void
BuildPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
// Build pattern
bool
with_fc_bias
)
{
PDNode
*
x
=
pattern
->
NewNode
(
patterns
::
PDNodeName
(
name_scope
,
"x"
))
PDNode
*
x
=
pattern
->
NewNode
(
name_scope
,
"x"
)
->
assert_is_op_input
(
"mul"
)
->
assert_is_op_input
(
"mul"
)
->
assert_var_not_persistable
();
->
assert_var_not_persistable
();
auto
*
fc_out
=
patterns
::
FC
(
pattern
,
name_scope
,
x
,
with_fc_bias
);
patterns
::
FC
fc_pattern
(
pattern
,
name_scope
);
fc_out
->
AsIntermediate
();
// fc_out is a tmp var, will be removed after fuse.
patterns
::
LSTM
(
pattern
,
name_scope
,
fc_out
);
// LOG(INFO) << "\n" << pattern->DotString();
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
,
bool
with_fc_bias
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildPattern
(
pattern
,
name_scope
,
with_fc_bias
);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto
*
fc_out
=
fc_pattern
(
x
,
with_fc_bias
)
->
AsIntermediate
();
patterns
::
LSTM
lstm_pattern
(
pattern
,
name_scope
);
lstm_pattern
(
fc_out
);
// Create New OpDesc
// Create New OpDesc
auto
lstm_creator
=
[
&
](
int
lstm
,
int
input
,
int
weight_x
,
int
weight_h
,
auto
lstm_creator
=
[
&
](
Node
*
lstm
,
Node
*
input
,
Node
*
weight_x
,
int
bias
,
int
hidden
,
int
cell
,
int
xx
,
int
fc_bias
)
{
Node
*
weight_h
,
Node
*
bias
,
Node
*
hidden
,
Node
*
cell
,
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
Node
*
xx
,
Node
*
fc_bias
)
{
GET_NODE
(
input
);
GET_NODE
(
weight_x
);
GET_NODE
(
weight_h
);
GET_NODE
(
bias
);
GET_NODE
(
hidden
);
GET_NODE
(
cell
);
GET_NODE
(
xx
);
GET_NODE
(
lstm
);
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_lstm"
);
op_desc
.
SetType
(
"fusion_lstm"
);
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__
##_n
->Name()});
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN
(
X
,
input
);
SET_IN
(
X
,
input
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightH
,
weight_h
);
SET_IN
(
WeightH
,
weight_h
);
...
@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto
*
bias_var
=
scope
->
Var
(
new_bias_var
);
auto
*
bias_var
=
scope
->
Var
(
new_bias_var
);
PADDLE_ENFORCE
(
bias_var
);
PADDLE_ENFORCE
(
bias_var
);
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
lstm_bias_var
=
scope
->
FindVar
(
bias
_n
->
Name
());
auto
*
lstm_bias_var
=
scope
->
FindVar
(
bias
->
Name
());
PADDLE_ENFORCE
(
lstm_bias_var
);
PADDLE_ENFORCE
(
lstm_bias_var
);
const
auto
&
lstm_bias_tensor
=
lstm_bias_var
->
Get
<
framework
::
LoDTensor
>
();
const
auto
&
lstm_bias_tensor
=
lstm_bias_var
->
Get
<
framework
::
LoDTensor
>
();
bias_tensor
->
Resize
(
lstm_bias_tensor
.
dims
());
bias_tensor
->
Resize
(
lstm_bias_tensor
.
dims
());
GET_NODE
(
fc_bias
);
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias
->
Name
());
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias_n
->
Name
());
const
auto
&
fc_bias_tensor
=
fc_bias_var
->
Get
<
framework
::
LoDTensor
>
();
const
auto
&
fc_bias_tensor
=
fc_bias_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
data
=
bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
data
=
bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
...
@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
}
}
op_desc
.
SetInput
(
"Bias"
,
{
new_bias_var
});
op_desc
.
SetInput
(
"Bias"
,
{
new_bias_var
});
}
}
#undef GET_NODE
// Create temp variables.
// Create temp variables.
scope
->
Var
(
name_scope
+
"/BatchedInput.new"
)
const
std
::
string
BatchedInput
=
patterns
::
UniqueKey
(
"BatchedInput"
);
->
GetMutable
<
framework
::
LoDTensor
>
();
const
std
::
string
BatchedCellPreAct
=
scope
->
Var
(
name_scope
+
"/BatchCellPreAct.new"
)
patterns
::
UniqueKey
(
"BatchedCellPreAct"
);
->
GetMutable
<
framework
::
LoDTensor
>
();
const
std
::
string
BatchedGate
=
patterns
::
UniqueKey
(
"BatchedGate"
);
scope
->
Var
(
name_scope
+
"/BatchedGate.new"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
->
Var
(
BatchedInput
)
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
->
Var
(
BatchedCellPreAct
)
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
->
Var
(
BatchedGate
)
->
GetMutable
<
framework
::
LoDTensor
>
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden
_n
->
Name
()});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden
->
Name
()});
op_desc
.
SetOutput
(
"Cell"
,
{
cell
_n
->
Name
()});
op_desc
.
SetOutput
(
"Cell"
,
{
cell
->
Name
()});
op_desc
.
SetOutput
(
"XX"
,
{
xx
_n
->
Name
()});
op_desc
.
SetOutput
(
"XX"
,
{
xx
->
Name
()});
op_desc
.
SetOutput
(
"BatchedGate"
,
{
name_scope
+
"/BatchedGate.new"
});
op_desc
.
SetOutput
(
"BatchedGate"
,
{
BatchedGate
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
name_scope
+
"/BatchCellPreAct.new"
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
BatchedCellPreAct
});
op_desc
.
SetOutput
(
"BatchedInput"
,
{
name_scope
+
"/BatchedInput.new"
});
op_desc
.
SetOutput
(
"BatchedInput"
,
{
BatchedInput
});
op_desc
.
SetAttr
(
"is_reverse"
,
lstm
_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"is_reverse"
,
lstm
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm
_n
->
Op
()
->
GetAttr
(
"use_peepholes"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm
->
Op
()
->
GetAttr
(
"use_peepholes"
));
// TODO(TJ): get from attr
// TODO(TJ): get from attr
op_desc
.
SetAttr
(
"use_seq"
,
true
);
op_desc
.
SetAttr
(
"use_seq"
,
true
);
#define TMP_NAME(x) "at.new.tmp." #x
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
auto
*
scope
=
graph
->
Get
<
Scope
*>
(
kParamScopeAttr
);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT
(
BatchedCell
);
OP_SET_OUT
(
BatchedCell
);
OP_SET_OUT
(
BatchedHidden
);
OP_SET_OUT
(
BatchedHidden
);
OP_SET_OUT
(
ReorderedH0
);
OP_SET_OUT
(
ReorderedH0
);
...
@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef OP_SET_OUT
#undef OP_SET_OUT
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
IR_NODE_LINK_TO
(
input
,
op
);
auto
*
scope
=
graph
->
Get
<
Scope
*>
(
kParamScopeAttr
);
IR_NODE_LINK_TO
(
weight_x
,
op
);
IR_NODE_LINK_TO
(
weight_h
,
op
);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
IR_NODE_LINK_TO
(
bias
,
op
);
TMP_NEW
(
BatchedCell
);
IR_NODE_LINK_TO
(
op
,
hidden
);
TMP_NEW
(
BatchedHidden
);
TMP_NEW
(
ReorderedH0
);
TMP_NEW
(
ReorderedC0
);
#undef TMP_NEW
#undef TMP_NAME
IR_NODE_LINK_TO
(
input_n
,
op
);
IR_NODE_LINK_TO
(
weight_x_n
,
op
);
IR_NODE_LINK_TO
(
weight_h_n
,
op
);
IR_NODE_LINK_TO
(
bias_n
,
op
);
IR_NODE_LINK_TO
(
op
,
hidden_n
);
return
op
;
return
op
;
};
};
...
@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
...
@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
PADDLE_ENFORCE(name__##n); \
PADDLE_ENFORCE(subgraph.count(name__##n)); \
Node* name__##_n = subgraph.at(name__##n); \
int name__ __attribute__((unused)) = name__##_n->id();
GET_NODE
(
x
);
GET_NODE
(
w
);
GET_NODE
(
mul
);
GET_NODE
(
fc_out
);
GET_NODE
(
Weight
);
GET_NODE
(
lstm
);
GET_NODE
(
Bias
);
GET_NODE
(
Hidden
);
GET_NODE
(
Cell
);
GET_IR_NODE_FROM_SUBGRAPH
(
lstm
,
lstm
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Weight
,
Weight
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Bias
,
Bias
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Cell
,
Cell
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Hidden
,
Hidden
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
w
,
w
,
fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul
,
mul
,
fc_pattern
);
if
(
with_fc_bias
)
{
if
(
with_fc_bias
)
{
GET_NODE
(
fc_bias
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
Out
,
fc_pattern
);
GET_NODE
(
elementwise_add
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_bias
,
bias
,
fc_pattern
);
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
fc_bias
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add
,
elementwise_add
,
fc_pattern
);
lstm_creator
(
lstm
,
subgraph
.
at
(
x
),
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
fc_bias
);
// Remove unneeded nodes.
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
mul
_n
,
lstm_n
,
elementwise_add_n
});
{
mul
,
lstm
,
elementwise_add
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
else
{
}
else
{
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
-
1
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
mul_out
,
fc_pattern
);
lstm_creator
(
lstm
,
subgraph
.
at
(
x
),
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
nullptr
);
// Remove unneeded nodes.
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
mul
_n
,
lstm_n
});
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
mul
,
lstm
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
}
#undef GET_NODE
++
fusion_count
;
++
fusion_count
;
};
};
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
478a4e85
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
...
@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
for
(
auto
&
pdnode
:
pattern_
.
nodes
())
{
for
(
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
!
pdnodes2nodes_
.
count
(
pdnode
.
get
()))
{
if
(
!
pdnodes2nodes_
.
count
(
pdnode
.
get
()))
{
VLOG
(
4
)
<<
pdnode
->
name
()
<<
" can't find matched Node, early stop"
;
VLOG
(
4
)
<<
pdnode
->
name
()
<<
" can't find matched Node, early stop"
;
// return false;
return
false
;
}
}
}
}
for
(
auto
&
item
:
pdnodes2nodes_
)
{
for
(
auto
&
item
:
pdnodes2nodes_
)
{
...
@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
...
@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
return
false
;
return
false
;
}
}
PDNode
*
patterns
::
FC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
patterns
::
FC
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
x
,
PDNode
*
x
,
bool
with_bias
)
{
bool
with_bias
)
{
// mul op
// Create shared nodes.
auto
*
mul_op
=
pattern
->
NewNode
(
name_scope
,
"mul"
)
->
assert_is_op
(
"mul"
);
x
->
assert_is_op_input
(
"mul"
,
"X"
);
auto
*
mul_weight_var
=
pattern
->
NewNode
(
name_scope
,
"w"
)
auto
*
mul
=
pattern
->
NewNode
(
mul_repr
())
->
assert_is_op
(
"mul"
);
->
AsInput
()
->
assert_is_persistable_var
()
auto
*
mul_w_var
=
pattern
->
NewNode
(
w_repr
())
->
assert_is_op_input
(
"mul"
,
"Y"
);
->
AsInput
()
->
assert_is_persistable_var
()
PDNode
*
fc_out
{
nullptr
};
->
assert_is_op_input
(
"mul"
,
"Y"
);
if
(
with_bias
)
{
PDNode
*
elementwise_add_op
{
nullptr
};
auto
*
mul_out_var
=
PDNode
*
mul_out_var
{
nullptr
},
*
bias
{
nullptr
};
pattern
->
NewNode
(
mul_out_repr
())
->
assert_is_op_output
(
"mul"
);
elementwise_add_op
=
pattern
->
NewNode
(
name_scope
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
);
if
(
!
with_bias
)
{
// not with bias
// intermediate variable, will be removed in the IR after fuse.
// Add links.
mul_out_var
=
pattern
->
NewNode
(
name_scope
,
"mul_out"
)
mul
->
LinksFrom
({
x
,
mul_w_var
}).
LinksTo
({
mul_out_var
});
->
AsIntermediate
()
return
mul_out_var
;
->
assert_is_only_output_of_op
(
"mul"
)
->
assert_is_op_input
(
"elementwise_add"
);
}
else
{
// with bias
// bias
mul_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
bias
=
pattern
->
NewNode
(
name_scope
,
"fc_bias"
)
// Create operators.
->
AsInput
()
auto
*
elementwise_add
=
pattern
->
NewNode
(
elementwise_add_repr
())
->
assert_is_op_input
(
"elementwise_add"
);
->
assert_is_op
(
"elementwise_add"
);
// output
// Create variables.
fc_out
=
pattern
->
NewNode
(
name_scope
,
"fc_out"
)
auto
*
bias
=
pattern
->
NewNode
(
bias_repr
())
->
AsOutput
()
->
assert_is_op_input
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
);
->
AsInput
();
mul_op
->
LinksFrom
({
x
,
mul_weight_var
}).
LinksTo
({
mul_out_var
});
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
bias
}).
LinksTo
({
fc_out
});
auto
*
fc_out
=
pattern
->
NewNode
(
Out_repr
())
}
else
{
->
AsOutput
()
fc_out
=
pattern
->
NewNode
(
name_scope
,
"fc_out"
)
->
assert_is_op_output
(
"elementwise_add"
);
->
AsOutput
()
->
assert_is_op_output
(
"mul"
);
mul
->
LinksFrom
({
mul_w_var
,
x
}).
LinksTo
({
mul_out_var
});
mul_op
->
LinksFrom
({
mul_weight_var
,
x
}).
LinksTo
({
fc_out
});
elementwise_add
->
LinksFrom
({
mul_out_var
,
bias
}).
LinksTo
({
fc_out
});
return
fc_out
;
}
}
return
fc_out
;
}
}
#define NEW_NODE(op__, arg__, io__) \
PDNode
*
patterns
::
LSTM
::
operator
()(
PDNode
*
x
)
{
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
->assert_is_op_##io__(#op__, #arg__);
PDNode
*
patterns
::
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lstm"
,
"Input"
);
x
->
assert_is_op_input
(
"lstm"
,
"Input"
);
auto
*
lstm_op
=
pattern
->
NewNode
(
name_scope
,
"lstm"
)
->
assert_is_op
(
"lstm"
);
auto
*
lstm_op
=
pattern
->
NewNode
(
lstm_repr
())
->
assert_is_op
(
"lstm"
);
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional
// Currently, the H0 and C0 are optional
// TODO(Superjomn) upgrade the fuse framework to support optional.
// TODO(Superjomn) upgrade the fuse framework to support optional.
// NEW_NODE(H0, input);
// NEW_NODE(H0, input);
// NEW_NODE(C0, input);
// NEW_NODE(C0, input);
NEW_NODE
(
lstm
,
Weight
,
input
);
NEW_NODE
(
Weight
,
input
);
NEW_NODE
(
lstm
,
Bias
,
input
);
NEW_NODE
(
Bias
,
input
);
NEW_NODE
(
lstm
,
Hidden
,
output
);
NEW_NODE
(
Hidden
,
output
);
NEW_NODE
(
lstm
,
Cell
,
output
);
NEW_NODE
(
Cell
,
output
);
NEW_NODE
(
lstm
,
BatchGate
,
output
);
NEW_NODE
(
BatchGate
,
output
);
NEW_NODE
(
lstm
,
BatchCellPreAct
,
output
);
NEW_NODE
(
BatchCellPreAct
,
output
);
#undef NEW_NODE
lstm_op
->
LinksFrom
({
x
,
Weight
,
Bias
});
lstm_op
->
LinksFrom
({
x
,
Weight
,
Bias
});
lstm_op
->
LinksTo
({
Hidden
,
Cell
,
BatchGate
,
BatchCellPreAct
});
lstm_op
->
LinksTo
({
Hidden
,
Cell
,
BatchGate
,
BatchCellPreAct
});
return
Hidden
;
return
Hidden
;
}
}
PDNode
*
patterns
::
GRU
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
patterns
::
GRU
::
operator
()(
PDNode
*
x
)
{
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"gru"
,
"Input"
);
x
->
assert_is_op_input
(
"gru"
,
"Input"
);
auto
*
gru_op
=
pattern
->
NewNode
(
name_scope
,
"gru"
)
->
assert_is_op
(
"gru"
);
auto
*
gru_op
=
pattern
->
NewNode
(
gru_repr
())
->
assert_is_op
(
"gru"
);
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__);
NEW_NODE
(
gru
,
Weight
,
input
);
NEW_NODE
(
Weight
,
input
);
// TODO(Superjomn): upgrade the fuse framework to support optional.
// TODO(Superjomn): upgrade the fuse framework to support optional.
// H0 and bias are optional
// H0 and bias are optional
NEW_NODE
(
gru
,
Bias
,
input
);
// also optional
NEW_NODE
(
Bias
,
input
);
// also optional
// NEW_NODE(H0, input);
// NEW_NODE(H0, input);
NEW_NODE
(
gru
,
Hidden
,
output
);
NEW_NODE
(
Hidden
,
output
);
// below are intermediate
// below are intermediate
NEW_NODE
(
gru
,
BatchGate
,
output
);
NEW_NODE
(
BatchGate
,
output
);
NEW_NODE
(
gru
,
BatchResetHiddenPrev
,
output
);
NEW_NODE
(
BatchResetHiddenPrev
,
output
);
NEW_NODE
(
gru
,
BatchHidden
,
output
);
NEW_NODE
(
BatchHidden
,
output
);
#undef NEW_NODE
BatchGate
->
AsIntermediate
();
BatchGate
->
AsIntermediate
();
BatchResetHiddenPrev
->
AsIntermediate
();
BatchResetHiddenPrev
->
AsIntermediate
();
...
@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
...
@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
gru_op
->
LinksTo
({
Hidden
,
BatchGate
,
BatchResetHiddenPrev
,
BatchHidden
});
gru_op
->
LinksTo
({
Hidden
,
BatchGate
,
BatchResetHiddenPrev
,
BatchHidden
});
return
Hidden
;
return
Hidden
;
}
}
#undef NEW_NODE
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
478a4e85
...
@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph,
...
@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph,
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
// Some pre-defined patterns those can be reused in multiple passes.
// Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better reusage
// accross different fusion.
namespace
patterns
{
namespace
patterns
{
struct
KeyCounter
{
static
KeyCounter
&
Instance
()
{
static
KeyCounter
x
;
return
x
;
}
int
IncCounter
(
const
std
::
string
&
key
)
{
return
dic_
[
key
]
++
;
}
private:
std
::
unordered_map
<
std
::
string
,
size_t
>
dic_
;
};
// Generate a unique PDNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static
std
::
string
PDNodeName
(
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
,
size_t
id
,
const
std
::
string
&
name
)
{
return
string
::
Sprintf
(
"%s/%s/%d/%s"
,
name_scope
,
repr
,
id
,
name
);
}
// Generate a unique PDNode's name.
// The format is {name_scope}/{repr}/{id}
static
std
::
string
PDNodeName
(
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
)
{
return
string
::
Sprintf
(
"%s/%s/%d"
,
name_scope
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static
std
::
string
UniqueKey
(
const
std
::
string
&
repr
)
{
return
string
::
Sprintf
(
"%s/%d"
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
}
// Declare a PDNode in a pattern, will create two methods:
// std::string xxx_repr(); return this PDNode's string id.
// PDNode* xxx_n(); return the corresponding PDNode.
#define PATTERN_DECL_NODE(name__) \
std::string name__##_repr() const { \
return PDNodeName(name_scope_, repr_, id_, #name__); \
} \
PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
// Get an ir::Node* from the matched subgraph.
// var: variable.
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
PADDLE_ENFORCE(subgraph.count(pat.arg##_n()), \
"Node not found for PDNode %s", pat.arg##_repr()); \
Node* var = subgraph.at(pat.arg##_n()); \
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg)
// The base class of all the patterns.
struct
PatternBase
{
PatternBase
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
)
:
pattern
(
pattern
),
name_scope_
(
name_scope
),
repr_
(
repr
),
id_
(
KeyCounter
::
Instance
().
IncCounter
(
repr
))
{}
PDPattern
*
pattern
;
protected:
std
::
string
name_scope_
;
std
::
string
repr_
;
size_t
id_
;
};
// FC with bias
// FC with bias
// op: mul + elementwise_add
// op: mul + elementwise_add
// named nodes:
// named nodes:
// mul, elementwise_add
// mul, elementwise_add
// w, mul_out, bias, fc_out
// w, mul_out, bias, fc_out
PDNode
*
FC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
,
struct
FC
:
public
PatternBase
{
bool
with_bias
);
FC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fc"
)
{}
PDNode
*
operator
()(
PDNode
*
x
,
bool
with_bias
);
// declare operator node's name
PATTERN_DECL_NODE
(
fc
);
PATTERN_DECL_NODE
(
mul
);
PATTERN_DECL_NODE
(
elementwise_add
);
// declare variable node's name
PATTERN_DECL_NODE
(
w
);
PATTERN_DECL_NODE
(
mul_out
);
// (x,w) -> mul_out
PATTERN_DECL_NODE
(
bias
);
PATTERN_DECL_NODE
(
Out
);
};
struct
LSTM
:
public
PatternBase
{
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"lstm"
)
{}
PDNode
*
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
);
PDNode
*
operator
()(
PDNode
*
x
);
PDNode
*
GRU
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
);
// Operators
PATTERN_DECL_NODE
(
lstm
);
// Inputs
PATTERN_DECL_NODE
(
Input
);
PATTERN_DECL_NODE
(
H0
);
PATTERN_DECL_NODE
(
C0
);
PATTERN_DECL_NODE
(
Weight
);
PATTERN_DECL_NODE
(
Bias
);
// Outputs
PATTERN_DECL_NODE
(
Hidden
);
PATTERN_DECL_NODE
(
Cell
);
PATTERN_DECL_NODE
(
BatchGate
);
PATTERN_DECL_NODE
(
BatchCellPreAct
);
};
struct
GRU
:
public
PatternBase
{
GRU
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"lstm"
)
{}
PDNode
*
operator
()(
PDNode
*
x
);
// Operators
PATTERN_DECL_NODE
(
gru
);
// Inputs
PATTERN_DECL_NODE
(
Bias
);
PATTERN_DECL_NODE
(
Weight
);
// Outputs
PATTERN_DECL_NODE
(
BatchGate
);
PATTERN_DECL_NODE
(
BatchResetHiddenPrev
);
PATTERN_DECL_NODE
(
BatchHidden
);
PATTERN_DECL_NODE
(
Hidden
);
};
}
// namespace patterns
}
// namespace patterns
// Link two ir::Nodes from each other.
#define IR_NODE_LINK_TO(a, b) \
#define IR_NODE_LINK_TO(a, b) \
a->outputs.push_back(b); \
a->outputs.push_back(b); \
b->inputs.push_back(a);
b->inputs.push_back(a);
...
...
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
浏览文件 @
478a4e85
...
@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
...
@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int
fuse_count
{
0
};
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Graph
*
graph
)
{
VLOG
(
4
)
<<
"get one concat pattern"
;
VLOG
(
4
)
<<
"get one concat pattern"
;
...
@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
...
@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
marked_nodes
.
erase
(
sequence_expand1_in
);
marked_nodes
.
erase
(
sequence_expand1_in
);
marked_nodes
.
erase
(
fc_out
);
marked_nodes
.
erase
(
fc_out
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fuse_count
;
});
});
AddStatis
(
fuse_count
);
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
478a4e85
...
@@ -267,6 +267,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
...
@@ -267,6 +267,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
PADDLE_ENFORCE
(
config
.
ir_mode
==
PADDLE_ENFORCE
(
config
.
ir_mode
==
AnalysisConfig
::
IrPassMode
::
kExclude
);
// default
AnalysisConfig
::
IrPassMode
::
kExclude
);
// default
config
.
ir_passes
.
clear
();
// Do not exclude any pass.
config
.
ir_passes
.
clear
();
// Do not exclude any pass.
int
batch_size
=
FLAGS_batch_size
;
int
batch_size
=
FLAGS_batch_size
;
int
num_times
=
FLAGS_repeat
;
int
num_times
=
FLAGS_repeat
;
...
@@ -346,6 +347,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
...
@@ -346,6 +347,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc_fuse"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
1
);
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
1
);
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_nobias_lstm_fuse"
),
2
);
// bi-directional LSTM
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_nobias_lstm_fuse"
),
2
);
// bi-directional LSTM
EXPECT_EQ
(
fuse_statis
.
at
(
"seq_concat_fc_fuse"
),
1
);
EXPECT_EQ
(
num_ops
,
EXPECT_EQ
(
num_ops
,
13
);
// After graph optimization, only 13 operators exists.
13
);
// After graph optimization, only 13 operators exists.
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录