BaiXuePrincess / Paddle — commit 5558784c (fork of PaddlePaddle / Paddle, in sync with the upstream project)
Commit 5558784c
Authored Sep 10, 2018 by Yancey1989

    Merge branch 'develop' of github.com:PaddlePaddle/Paddle into reset_vars_on_pserver

Parents: 32b94a7d, 5023530a

Showing 8 changed files with 337 additions and 275 deletions (+337 / -275)
paddle/fluid/framework/ir/fc_fuse_pass.cc              +11   -22
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc          +44   -62
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc         +61   -91
paddle/fluid/framework/ir/graph_pattern_detector.cc    +62   -61
paddle/fluid/framework/ir/graph_pattern_detector.h     +130  -4
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc   +6    -0
paddle/fluid/inference/analysis/CMakeLists.txt         +8    -8
paddle/fluid/inference/analysis/analyzer_tester.cc     +15   -27
paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   std::unordered_set<Node*> nodes2delete;
 
   GraphPatternDetector gpd;
-  // BuildFCPattern(gpd.mutable_pattern());
   auto* x = gpd.mutable_pattern()
                 ->NewNode("fc_fuse/x")
                 ->AsInput()
                 ->assert_is_op_input("mul", "X");
-  patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);
-
-#define GET_NODE(id)                                                          \
-  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
-                 "pattern has no Node called %s", #id);                      \
-  auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id));        \
-  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
+  patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
+  fc_pattern(x, true /*with bias*/);
 
   int found_fc_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
-    // Currently, there is no FC op available, so I will just simulate the
-    // scenerio.
-    // FC's fusion is simple, just op fuse, no need to process the
-    // parameters.
-    GET_NODE(x);                // x
-    GET_NODE(w);                // Y
-    GET_NODE(fc_bias);          // bias
-    GET_NODE(fc_out);           // Out
-    GET_NODE(mul);              // MUL op
-    GET_NODE(elementwise_add);  // ELEMENT_ADD op
-    GET_NODE(mul_out);          // tmp
-#undef GET_NODE
+    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
 
     // Create an FC Node.
     OpDesc desc;
-    std::string fc_x_in = x->Name();
+    std::string fc_x_in = subgraph.at(x)->Name();
     std::string fc_Y_in = w->Name();
     std::string fc_bias_in = fc_bias->Name();
     std::string fc_out_out = fc_out->Name();
@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
     GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
 
-    IR_NODE_LINK_TO(x, fc_node);
+    PADDLE_ENFORCE(subgraph.count(x));
+    IR_NODE_LINK_TO(subgraph.at(x), fc_node);
     IR_NODE_LINK_TO(w, fc_node);
     IR_NODE_LINK_TO(fc_bias, fc_node);
     IR_NODE_LINK_TO(fc_node, fc_out);

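Note for readers of this diff: the GET_IR_NODE_FROM_SUBGRAPH macro used above is introduced in paddle/fluid/framework/ir/graph_pattern_detector.h later in this same commit. As a reading aid only (not part of the change itself), a call such as GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern) expands roughly to:

    // Expansion sketch per the macro definition added in this commit.
    PADDLE_ENFORCE(subgraph.count(fc_pattern.w_n()),
                   "Node not found for PDNode %s", fc_pattern.w_repr());
    Node* w = subgraph.at(fc_pattern.w_n());
    PADDLE_ENFORCE(w, "node %s not exists in the sub-graph", "w");

That is, it looks up the PDNode declared by the pattern object, checks that the matched subgraph contains it, and binds the corresponding ir::Node* to a local variable.
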
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -20,52 +20,43 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
-                         bool with_fc_bias) {
-  PDNode* x = pattern->NewNode(name_scope, "x")
-                  ->assert_is_op_input("mul")
-                  ->assert_var_not_persistable();
-  auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
-  fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
-  patterns::GRU(pattern, name_scope, fc_out);
-  VLOG(3) << "fc_gru pattern \n" << pattern->DotString();
-}
-
 static int BuildFusion(Graph* graph, const std::string& name_scope,
                        Scope* scope, bool with_fc_bias) {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
-  BuildPattern(pattern, name_scope, with_fc_bias);
+  // Create pattern.
+  patterns::FC fc_pattern(pattern, name_scope);
+  patterns::GRU gru_pattern(pattern, name_scope);
+
+  PDNode* x =
+      pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
+
+  auto* fc_out = fc_pattern(x, with_fc_bias);
+  fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
+  gru_pattern(fc_out);
 
   // Create New OpDesc
-  auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias,
-                         int hidden, int fc_bias) {
-#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
-    GET_NODE(x);
-    GET_NODE(weight_x);
-    GET_NODE(weight_h);
-    GET_NODE(bias);
-    GET_NODE(hidden);
-    GET_NODE(gru);
+  auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
+                         Node* bias, Node* hidden, Node* fc_bias) {
     OpDesc op_desc;
     op_desc.SetType("fusion_gru");
 
 #define NEW_NAME(x) name_scope + "/at." #x ".new"
-#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
     SET_IN(X, x);
     SET_IN(WeightX, weight_x);
     SET_IN(WeightH, weight_h);
     if (with_fc_bias) {
-      op_desc.SetInput("Bias", {NEW_NAME(bias) + bias_n->Name()});
+      op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
     } else {
       SET_IN(Bias, bias);
     }
 #undef SET_IN
     op_desc.SetInput("H0", {});
-    op_desc.SetOutput("Hidden", {hidden_n->Name()});
-    op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse"));
+    op_desc.SetOutput("Hidden", {hidden->Name()});
+    op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
     // TODO(TJ): This should be a option for infer
     op_desc.SetAttr("use_seq", true);
@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     PADDLE_ENFORCE(scope);
     if (with_fc_bias) {
       // Fusion GRU bias = fcbias + grubias
-      auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias_n->Name());
+      auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
       auto* out_bias_tensor =
           fusion_bias_var->GetMutable<framework::LoDTensor>();
       PADDLE_ENFORCE(fusion_bias_var);
-      GET_NODE(fc_bias);
-      PADDLE_ENFORCE(fc_bias_n);
-      auto* gru_bias_var = scope->FindVar(bias_n->Name());
-      auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
+      auto* gru_bias_var = scope->FindVar(bias->Name());
+      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
       PADDLE_ENFORCE(gru_bias_var);
       PADDLE_ENFORCE(fc_bias_var);
       const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef NEW_NAME
 #undef NEW_IMTERMEDIATE_OUT
 
-    IR_NODE_LINK_TO(x_n, op);
-    IR_NODE_LINK_TO(weight_x_n, op);
-    IR_NODE_LINK_TO(weight_h_n, op);
-    IR_NODE_LINK_TO(bias_n, op);  // actually should link to new bias if have
-    IR_NODE_LINK_TO(op, hidden_n);
+    IR_NODE_LINK_TO(x, op);
+    IR_NODE_LINK_TO(weight_x, op);
+    IR_NODE_LINK_TO(weight_h, op);
+    IR_NODE_LINK_TO(bias, op);  // actually should link to new bias if have
+    IR_NODE_LINK_TO(op, hidden);
     // h0?
     return op;
   };
@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   int fusion_count{0};
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-#define GET_NODE(name__)                                  \
-  std::string name__##key = name_scope + "/" + #name__;   \
-  auto* name__##n = pattern->RetrieveNode(name__##key);   \
-  PADDLE_ENFORCE(name__##n);                              \
-  PADDLE_ENFORCE(subgraph.count(name__##n));              \
-  Node* name__##_n = subgraph.at(name__##n);              \
-  int name__ __attribute__((unused)) = name__##_n->id();
-
-    GET_NODE(x);
-    GET_NODE(w);  // fc weight
-    GET_NODE(mul);
-    GET_NODE(fc_out);
-    GET_NODE(Weight);
-    GET_NODE(gru);
-    GET_NODE(Bias);
-    GET_NODE(Hidden);
+    auto* x_n = subgraph.at(x);
+    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
     // nodes need be removed
-    GET_NODE(BatchGate);
-    GET_NODE(BatchResetHiddenPrev);
-    GET_NODE(BatchHidden);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern);
 
     if (with_fc_bias) {
-      GET_NODE(mul_out);
-      GET_NODE(fc_bias);
-      GET_NODE(elementwise_add);
-      gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias);
+      GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
+      gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n,
-           BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
+          {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate,
+           BatchResetHiddenPrev, BatchHidden});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {
-      gru_creater(gru, x, w, Weight, Bias, Hidden, -1);
+      gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
+          {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
       GraphSafeRemoveNodes(graph, marked_nodes);
     }
 #undef GET_NODE

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -20,45 +20,29 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-static std::string GenNodeName(const std::string& prefix,
-                               const std::string& name) {
-  return prefix + "/" + name;
-}
-
-static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
-                         bool with_fc_bias) {
-  PDNode* x = pattern->NewNode(name_scope, "x")
-                  ->assert_is_op_input("mul")
-                  ->assert_var_not_persistable();
-  auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
-  fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
-  patterns::LSTM(pattern, name_scope, fc_out);
-  // LOG(INFO) << "\n" << pattern->DotString();
-}
-
-static int BuildFusion(Graph* graph, const std::string& name_scope,
-                       Scope* scope, bool with_fc_bias) {
+int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
+                bool with_fc_bias) {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
-  BuildPattern(pattern, name_scope, with_fc_bias);
+  // Build pattern
+  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
+                  ->assert_is_op_input("mul")
+                  ->assert_var_not_persistable();
+  patterns::FC fc_pattern(pattern, name_scope);
+
+  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
+  auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate();
+  patterns::LSTM lstm_pattern(pattern, name_scope);
+  lstm_pattern(fc_out);
 
   // Create New OpDesc
-  auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
-                          int bias, int hidden, int cell, int xx, int fc_bias) {
-#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
-    GET_NODE(input);
-    GET_NODE(weight_x);
-    GET_NODE(weight_h);
-    GET_NODE(bias);
-    GET_NODE(hidden);
-    GET_NODE(cell);
-    GET_NODE(xx);
-    GET_NODE(lstm);
+  auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x,
+                          Node* weight_h, Node* bias, Node* hidden, Node* cell,
+                          Node* xx, Node* fc_bias) {
     OpDesc op_desc;
     op_desc.SetType("fusion_lstm");
-#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
     SET_IN(X, input);
     SET_IN(WeightX, weight_x);
     SET_IN(WeightH, weight_h);
@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       auto* bias_var = scope->Var(new_bias_var);
       PADDLE_ENFORCE(bias_var);
       auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
-      auto* lstm_bias_var = scope->FindVar(bias_n->Name());
+      auto* lstm_bias_var = scope->FindVar(bias->Name());
       PADDLE_ENFORCE(lstm_bias_var);
       const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
       bias_tensor->Resize(lstm_bias_tensor.dims());
 
-      GET_NODE(fc_bias);
-      auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
+      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
       const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
 
       auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       }
       op_desc.SetInput("Bias", {new_bias_var});
     }
-#undef GET_NODE
 
     // Create temp variables.
-    scope->Var(name_scope + "/BatchedInput.new")
-        ->GetMutable<framework::LoDTensor>();
-    scope->Var(name_scope + "/BatchCellPreAct.new")
-        ->GetMutable<framework::LoDTensor>();
-    scope->Var(name_scope + "/BatchedGate.new")
-        ->GetMutable<framework::LoDTensor>();
+    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
+    const std::string BatchedCellPreAct =
+        patterns::UniqueKey("BatchedCellPreAct");
+    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
+
+    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
+    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
 
     op_desc.SetInput("H0", {});
     op_desc.SetInput("C0", {});
-    op_desc.SetOutput("Hidden", {hidden_n->Name()});
-    op_desc.SetOutput("Cell", {cell_n->Name()});
-    op_desc.SetOutput("XX", {xx_n->Name()});
-    op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"});
-    op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"});
-    op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"});
-    op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
-    op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
+    op_desc.SetOutput("Hidden", {hidden->Name()});
+    op_desc.SetOutput("Cell", {cell->Name()});
+    op_desc.SetOutput("XX", {xx->Name()});
+    op_desc.SetOutput("BatchedGate", {BatchedGate});
+    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
+    op_desc.SetOutput("BatchedInput", {BatchedInput});
+    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
+    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
     // TODO(TJ): get from attr
     op_desc.SetAttr("use_seq", true);
 
-#define TMP_NAME(x) "at.new.tmp." #x
-#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+#define OP_SET_OUT(x)                            \
+  const std::string x = patterns::UniqueKey(#x); \
+  op_desc.SetOutput(#x, {x});                    \
+  scope->Var(x)->GetMutable<LoDTensor>()
     OP_SET_OUT(BatchedCell);
     OP_SET_OUT(BatchedHidden);
     OP_SET_OUT(ReorderedH0);
@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef OP_SET_OUT
 
     auto* op = graph->CreateOpNode(&op_desc);
-    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
-
-#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
-    TMP_NEW(BatchedCell);
-    TMP_NEW(BatchedHidden);
-    TMP_NEW(ReorderedH0);
-    TMP_NEW(ReorderedC0);
-#undef TMP_NEW
-#undef TMP_NAME
-
-    IR_NODE_LINK_TO(input_n, op);
-    IR_NODE_LINK_TO(weight_x_n, op);
-    IR_NODE_LINK_TO(weight_h_n, op);
-    IR_NODE_LINK_TO(bias_n, op);
-    IR_NODE_LINK_TO(op, hidden_n);
+    IR_NODE_LINK_TO(input, op);
+    IR_NODE_LINK_TO(weight_x, op);
+    IR_NODE_LINK_TO(weight_h, op);
+    IR_NODE_LINK_TO(bias, op);
+    IR_NODE_LINK_TO(op, hidden);
     return op;
   };
 
@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-#define GET_NODE(name__)                                  \
-  std::string name__##key = name_scope + "/" + #name__;   \
-  auto* name__##n = pattern->RetrieveNode(name__##key);   \
-  PADDLE_ENFORCE(name__##n);                              \
-  PADDLE_ENFORCE(subgraph.count(name__##n));              \
-  Node* name__##_n = subgraph.at(name__##n);              \
-  int name__ __attribute__((unused)) = name__##_n->id();
-
-    GET_NODE(x);
-    GET_NODE(w);
-    GET_NODE(mul);
-    GET_NODE(fc_out);
-    GET_NODE(Weight);
-    GET_NODE(lstm);
-    GET_NODE(Bias);
-    GET_NODE(Hidden);
-    GET_NODE(Cell);
+    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
 
     if (with_fc_bias) {
-      GET_NODE(fc_bias);
-      GET_NODE(elementwise_add);
-      lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
+      lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
+                   fc_bias);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul_n, lstm_n, elementwise_add_n});
+          {mul, lstm, elementwise_add});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {
-      lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
+      lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
+                   nullptr);
      // Remove unneeded nodes.
-      std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
+      std::unordered_set<const Node*> marked_nodes({mul, lstm});
       GraphSafeRemoveNodes(graph, marked_nodes);
     }
-#undef GET_NODE
 
     ++fusion_count;
   };

paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/printf.h"
 
 namespace paddle {
 namespace framework {
@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
   for (auto& pdnode : pattern_.nodes()) {
     if (!pdnodes2nodes_.count(pdnode.get())) {
       VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
-
-      // return false;
+      return false;
     }
   }
   for (auto& item : pdnodes2nodes_) {
@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
   return false;
 }
 
-PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
-                     PDNode* x, bool with_bias) {
-  // mul op
-  auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
-  auto* mul_weight_var = pattern->NewNode(name_scope, "w")
-                             ->AsInput()
-                             ->assert_is_persistable_var()
-                             ->assert_is_op_input("mul", "Y");
-  PDNode* fc_out{nullptr};
-  if (with_bias) {
-    PDNode* elementwise_add_op{nullptr};
-    PDNode* mul_out_var{nullptr}, *bias{nullptr};
-    elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
-                             ->assert_is_op("elementwise_add");
-    // intermediate variable, will be removed in the IR after fuse.
-    mul_out_var = pattern->NewNode(name_scope, "mul_out")
-                      ->AsIntermediate()
-                      ->assert_is_only_output_of_op("mul")
-                      ->assert_is_op_input("elementwise_add");
-    // bias
-    bias = pattern->NewNode(name_scope, "fc_bias")
-               ->AsInput()
-               ->assert_is_op_input("elementwise_add");
-    // output
-    fc_out = pattern->NewNode(name_scope, "fc_out")
-                 ->AsOutput()
-                 ->assert_is_op_output("elementwise_add");
-    mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
-    elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
-  } else {
-    fc_out = pattern->NewNode(name_scope, "fc_out")
-                 ->AsOutput()
-                 ->assert_is_op_output("mul");
-    mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
-  }
-  return fc_out;
-}
+PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
+                                 bool with_bias) {
+  // Create shared nodes.
+  x->assert_is_op_input("mul", "X");
+  auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
+
+  auto* mul_w_var = pattern->NewNode(w_repr())
+                        ->AsInput()
+                        ->assert_is_persistable_var()
+                        ->assert_is_op_input("mul", "Y");
+
+  auto* mul_out_var =
+      pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
+
+  if (!with_bias) {  // not with bias
+    // Add links.
+    mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
+    return mul_out_var;
+  } else {  // with bias
+    mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+    // Create operators.
+    auto* elementwise_add = pattern->NewNode(elementwise_add_repr())
+                                ->assert_is_op("elementwise_add");
+    // Create variables.
+    auto* bias = pattern->NewNode(bias_repr())
+                     ->assert_is_op_input("elementwise_add")
+                     ->AsInput();
+
+    auto* fc_out = pattern->NewNode(Out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("elementwise_add");
+
+    mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var});
+    elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
+    return fc_out;
+  }
+}
 
-#define NEW_NODE(op__, arg__, io__)                  \
-  auto* arg__ = pattern->NewNode(name_scope, #arg__) \
-                    ->assert_is_op_##io__(#op__, #arg__);
-
-PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
-                       PDNode* x) {
+PDNode* patterns::LSTM::operator()(PDNode* x) {
   x->assert_is_op_input("lstm", "Input");
-  auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm");
+  auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
+#define NEW_NODE(arg__, io__) \
+  auto* arg__ =               \
+      pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__);
 
   // Currently, the H0 and C0 are optional
   // TODO(Superjomn) upgrade the fuse framework to support optional.
   // NEW_NODE(H0, input);
   // NEW_NODE(C0, input);
-  NEW_NODE(lstm, Weight, input);
-  NEW_NODE(lstm, Bias, input);
-
-  NEW_NODE(lstm, Hidden, output);
-  NEW_NODE(lstm, Cell, output);
-  NEW_NODE(lstm, BatchGate, output);
-  NEW_NODE(lstm, BatchCellPreAct, output);
+  NEW_NODE(Weight, input);
+  NEW_NODE(Bias, input);
+
+  NEW_NODE(Hidden, output);
+  NEW_NODE(Cell, output);
+  NEW_NODE(BatchGate, output);
+  NEW_NODE(BatchCellPreAct, output);
+#undef NEW_NODE
 
   lstm_op->LinksFrom({x, Weight, Bias});
   lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
   return Hidden;
 }
 
-PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
-                      PDNode* x) {
+PDNode* patterns::GRU::operator()(PDNode* x) {
   x->assert_is_op_input("gru", "Input");
-  auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru");
+  auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru");
+#define NEW_NODE(arg__, io__) \
+  auto* arg__ =               \
+      pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__);
 
-  NEW_NODE(gru, Weight, input);
+  NEW_NODE(Weight, input);
   // TODO(Superjomn): upgrade the fuse framework to support optional.
   // H0 and bias are optional
-  NEW_NODE(gru, Bias, input);  // also optional
+  NEW_NODE(Bias, input);  // also optional
   // NEW_NODE(H0, input);
 
-  NEW_NODE(gru, Hidden, output);
+  NEW_NODE(Hidden, output);
   // below are intermediate
-  NEW_NODE(gru, BatchGate, output);
-  NEW_NODE(gru, BatchResetHiddenPrev, output);
-  NEW_NODE(gru, BatchHidden, output);
+  NEW_NODE(BatchGate, output);
+  NEW_NODE(BatchResetHiddenPrev, output);
+  NEW_NODE(BatchHidden, output);
+#undef NEW_NODE
 
   BatchGate->AsIntermediate();
   BatchResetHiddenPrev->AsIntermediate();
@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
   gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
   return Hidden;
 }
-#undef NEW_NODE
 
 }  // namespace ir
 }  // namespace framework

paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph,
                           const std::unordered_set<const Node*>& nodes);
 
 // Some pre-defined patterns those can be reused in multiple passes.
+// The related Fluid Layer or Op should be one pattern here for better reusage
+// accross different fusion.
 namespace patterns {
 
+struct KeyCounter {
+  static KeyCounter& Instance() {
+    static KeyCounter x;
+    return x;
+  }
+
+  int IncCounter(const std::string& key) { return dic_[key]++; }
+
+ private:
+  std::unordered_map<std::string, size_t> dic_;
+};
+
+// Generate a unique PDNode's name with name_scope and id.
+// The format is {name_scope}/{repr}/{id}/{name}
+static std::string PDNodeName(const std::string& name_scope,
+                              const std::string& repr, size_t id,
+                              const std::string& name) {
+  return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
+}
+// Generate a unique PDNode's name.
+// The format is {name_scope}/{repr}/{id}
+static std::string PDNodeName(const std::string& name_scope,
+                              const std::string& repr) {
+  return string::Sprintf("%s/%s/%d", name_scope, repr,
+                         KeyCounter::Instance().IncCounter(repr));
+}
+// Generate a unique key. It can be used for a universally unique temporary
+// name.
+// The format is {repr}/{id}
+static std::string UniqueKey(const std::string& repr) {
+  return string::Sprintf("%s/%d", repr,
+                         KeyCounter::Instance().IncCounter(repr));
+}
+
+// Declare a PDNode in a pattern, will create two methods:
+// std::string xxx_repr(); return this PDNode's string id.
+// PDNode* xxx_n(); return the corresponding PDNode.
+#define PATTERN_DECL_NODE(name__)                        \
+  std::string name__##_repr() const {                    \
+    return PDNodeName(name_scope_, repr_, id_, #name__); \
+  }                                                      \
+  PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
+
+// Get an ir::Node* from the matched subgraph.
+// var: variable.
+// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
+// pat: the pattern object.
+#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat)                    \
+  PADDLE_ENFORCE(subgraph.count(pat.arg##_n()),                     \
+                 "Node not found for PDNode %s", pat.arg##_repr()); \
+  Node* var = subgraph.at(pat.arg##_n());                           \
+  PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg)
+
+// The base class of all the patterns.
+struct PatternBase {
+  PatternBase(PDPattern* pattern, const std::string& name_scope,
+              const std::string& repr)
+      : pattern(pattern),
+        name_scope_(name_scope),
+        repr_(repr),
+        id_(KeyCounter::Instance().IncCounter(repr)) {}
+
+  PDPattern* pattern;
+
+ protected:
+  std::string name_scope_;
+  std::string repr_;
+  size_t id_;
+};
+
 // FC with bias
 // op: mul + elementwise_add
 // named nodes:
 // mul, elementwise_add
 // w, mul_out, bias, fc_out
-PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x,
-           bool with_bias);
+struct FC : public PatternBase {
+  FC(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "fc") {}
+
+  PDNode* operator()(PDNode* x, bool with_bias);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fc);
+  PATTERN_DECL_NODE(mul);
+  PATTERN_DECL_NODE(elementwise_add);
+  // declare variable node's name
+  PATTERN_DECL_NODE(w);
+  PATTERN_DECL_NODE(mul_out);  // (x,w) -> mul_out
+  PATTERN_DECL_NODE(bias);
+  PATTERN_DECL_NODE(Out);
+};
 
-PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
+struct LSTM : public PatternBase {
+  LSTM(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "lstm") {}
 
-PDNode* GRU(PDPattern* pattern, const std::string& name_scope, PDNode* x);
+  PDNode* operator()(PDNode* x);
+
+  // Operators
+  PATTERN_DECL_NODE(lstm);
+
+  // Inputs
+  PATTERN_DECL_NODE(Input);
+  PATTERN_DECL_NODE(H0);
+  PATTERN_DECL_NODE(C0);
+  PATTERN_DECL_NODE(Weight);
+  PATTERN_DECL_NODE(Bias);
+
+  // Outputs
+  PATTERN_DECL_NODE(Hidden);
+  PATTERN_DECL_NODE(Cell);
+  PATTERN_DECL_NODE(BatchGate);
+  PATTERN_DECL_NODE(BatchCellPreAct);
+};
+
+struct GRU : public PatternBase {
+  GRU(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "lstm") {}
+
+  PDNode* operator()(PDNode* x);
+
+  // Operators
+  PATTERN_DECL_NODE(gru);
+
+  // Inputs
+  PATTERN_DECL_NODE(Bias);
+  PATTERN_DECL_NODE(Weight);
+
+  // Outputs
+  PATTERN_DECL_NODE(BatchGate);
+  PATTERN_DECL_NODE(BatchResetHiddenPrev);
+  PATTERN_DECL_NODE(BatchHidden);
+  PATTERN_DECL_NODE(Hidden);
+};
 
 }  // namespace patterns
 
+// Link two ir::Nodes from each other.
 #define IR_NODE_LINK_TO(a, b) \
   a->outputs.push_back(b);    \
   b->inputs.push_back(a);

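The structs above replace the former free functions patterns::FC/LSTM/GRU. A minimal usage sketch follows (illustrative only, assuming the Paddle headers changed in this commit; the scope name "demo" and the function RunFCDetector are hypothetical, not part of the change):

    #include "paddle/fluid/framework/ir/graph_pattern_detector.h"

    namespace paddle {
    namespace framework {
    namespace ir {

    // Sketch: detect FC (mul + elementwise_add) subgraphs with the new
    // PatternBase-style API and report the matched nodes.
    void RunFCDetector(Graph* graph) {
      GraphPatternDetector gpd;
      auto* pattern = gpd.mutable_pattern();

      // Input variable feeding the mul op of the FC pattern.
      PDNode* x = pattern->NewNode(patterns::PDNodeName("demo", "x"))
                      ->AsInput()
                      ->assert_is_op_input("mul", "X");

      // Instantiate the reusable pattern and wire it to x (with bias).
      patterns::FC fc_pattern(pattern, "demo");
      fc_pattern(x, true /*with bias*/);

      auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                         Graph* g) {
        // Bind matched ir::Node*s by the names declared with PATTERN_DECL_NODE.
        GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
        GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
        GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
        VLOG(4) << "matched FC: mul=" << mul->Name() << ", w=" << w->Name()
                << ", out=" << fc_out->Name();
      };
      gpd(graph, handler);
    }

    }  // namespace ir
    }  // namespace framework
    }  // namespace paddle

This mirrors what fc_fuse_pass.cc does in this commit: the pattern object owns the generated node names, and GET_IR_NODE_FROM_SUBGRAPH retrieves the matched ir::Node* for each declared PDNode.
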
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
   auto* id = subgraph.at(pattern.RetrieveNode(#id));           \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
+  int fuse_count{0};
+
   detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
                             Graph* graph) {
     VLOG(4) << "get one concat pattern";
@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
     marked_nodes.erase(sequence_expand1_in);
     marked_nodes.erase(fc_out);
     GraphSafeRemoveNodes(graph, marked_nodes);
+
+    ++fuse_count;
   });
 
+  AddStatis(fuse_count);
+
   return graph;
 }

paddle/fluid/inference/analysis/CMakeLists.txt
@@ -48,18 +48,18 @@ function (inference_download_and_uncompress install_dir url gz_filename)
   message(STATUS "finish downloading ${gz_filename}")
 endfunction(inference_download_and_uncompress)
 
-set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz")
-set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz")
-set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE)
-if (NOT EXISTS ${DITU_INSTALL_DIR} AND WITH_TESTING)
-    inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz")
-    inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz")
+set(RNN1_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/rnn1%2Fmodel.tar.gz")
+set(RNN1_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/rnn1%2Fdata.txt.tar.gz")
+set(RNN1_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/rnn1" CACHE PATH "RNN1 model and data root." FORCE)
+if (NOT EXISTS ${RNN1_INSTALL_DIR} AND WITH_TESTING)
+    inference_download_and_uncompress(${RNN1_INSTALL_DIR} ${RNN1_MODEL_URL} "rnn1%2Fmodel.tar.gz")
+    inference_download_and_uncompress(${RNN1_INSTALL_DIR} ${RNN1_DATA_URL} "rnn1%2Fdata.txt.tar.gz")
 endif()
 
 inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
     EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
-    ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
-         --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
+    ARGS --infer_model=${RNN1_INSTALL_DIR}/model
+         --infer_data=${RNN1_INSTALL_DIR}/data.txt)
 inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
 inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)

paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -26,8 +26,8 @@
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
-DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
-DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
+DEFINE_string(infer_model, "", "model path");
+DEFINE_string(infer_data, "", "data path");
 DEFINE_int32(batch_size, 10, "batch size.");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
 DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
@@ -223,17 +223,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
 }  // namespace
 
-const float ditu_rnn_target_data[] = {
-    104.711, 11.2431, 1.35422, 0,       0,       0,       0,       0,
-    27.7039, 1.41486, 7.09526, 0,       0,       0,       0,       0,
-    7.6481,  6.5324,  56.383,  2.88018, 8.92918, 132.007, 4.27429, 2.02934,
-    14.1727, 10.7461, 25.0616, 16.0197, 14.4163, 16.9199, 6.75517, 0,
-    80.0249, 4.77739, 0,       0,       0,       0,       0,       0,
-    47.5643, 2.67029, 8.76252, 0,       0,       0,       0,       0,
-    51.8822, 4.4411,  0,       0,       0,       0,       0,       0,
-    10.7286, 12.0595, 10.6672, 0,       0,       0,       0,       0,
-    93.5771, 3.84641, 0,       0,       0,       0,       0,       0,
-    169.426, 0,       0,       0,       0,       0,       0,       0};
 void CompareResult(const std::vector<PaddleTensor> &outputs,
                    const std::vector<PaddleTensor> &base_outputs) {
   PADDLE_ENFORCE_GT(outputs.size(), 0);
@@ -255,11 +244,10 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 // Test with a really complicate model.
-void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
-                           int num_threads) {
+void TestRNN1Prediction(bool use_analysis, bool activate_ir,
+                        int num_threads) {
   AnalysisConfig config;
-  config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
-  config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
+  config.prog_file = FLAGS_infer_model + "/__model__";
+  config.param_file = FLAGS_infer_model + "/param";
   config.use_gpu = false;
   config.device = 0;
   config.specify_input_name = true;
@@ -267,6 +255,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
   PADDLE_ENFORCE(config.ir_mode ==
                  AnalysisConfig::IrPassMode::kExclude);  // default
   config.ir_passes.clear();  // Do not exclude any pass.
+
   int batch_size = FLAGS_batch_size;
   int num_times = FLAGS_repeat;
@@ -276,7 +265,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
       CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
          config);
   std::vector<PaddleTensor> input_slots;
-  DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
+  DataRecord data(FLAGS_infer_data, batch_size);
   // Prepare inputs.
   PrepareInputs(&input_slots, &data, batch_size);
   std::vector<PaddleTensor> outputs, base_outputs;
@@ -306,7 +295,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
       threads.emplace_back([&, tid]() {
         // Each thread should have local input_slots and outputs.
         std::vector<PaddleTensor> input_slots;
-        DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
+        DataRecord data(FLAGS_infer_data, batch_size);
         PrepareInputs(&input_slots, &data, batch_size);
         std::vector<PaddleTensor> outputs;
         Timer timer;
@@ -346,30 +335,29 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
     ASSERT_TRUE(fuse_statis.count("fc_fuse"));
     EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
     EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2);  // bi-directional LSTM
+    EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
     EXPECT_EQ(num_ops,
               13);  // After graph optimization, only 13 operators exists.
   }
 }
 
 // Inference with analysis and IR, easy for profiling independently.
-TEST(Analyzer, DituRNN) {
-  TestDituRNNPrediction(true, true, FLAGS_num_threads);
-}
+TEST(Analyzer, rnn1) { TestRNN1Prediction(true, true, FLAGS_num_threads); }
 
-// Other unit-tests of DituRNN, test different options of use_analysis,
+// Other unit-tests of RNN1, test different options of use_analysis,
 // activate_ir and multi-threads.
-TEST(Analyzer, DituRNN_tests) {
+TEST(Analyzer, RNN_tests) {
   int num_threads[2] = {1, 4};
   for (auto i : num_threads) {
     // Directly infer with the original model.
-    TestDituRNNPrediction(false, false, i);
+    TestRNN1Prediction(false, false, i);
     // Inference with the original model with the analysis turned on, the
     // analysis
     // module will transform the program to a data flow graph.
-    TestDituRNNPrediction(true, false, i);
+    TestRNN1Prediction(true, false, i);
     // Inference with analysis and IR. The IR module will fuse some large
     // kernels.
-    TestDituRNNPrediction(true, true, i);
+    TestRNN1Prediction(true, true, i);
   }
 }