Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
597b7305
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
597b7305
编写于
9月 03, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
9月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine/fc lstm fusion link (#13158)
上级
1e7ccf9f
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
312 addition
and
106 deletion
+312
-106
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+105
-45
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+16
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+148
-4
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+34
-55
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+7
-0
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+1
-0
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+1
-0
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
597b7305
...
@@ -13,37 +13,37 @@
...
@@ -13,37 +13,37 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
std
::
unique_ptr
<
ir
::
Graph
>
FCLstmFusePass
::
ApplyImpl
(
std
::
string
GenNodeName
(
const
std
::
string
&
prefix
,
const
std
::
string
&
name
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
return
prefix
+
"/"
+
name
;
GraphPatternDetector
gpd
;
}
auto
*
pattern
=
gpd
.
mutable_pattern
();
std
::
unordered_set
<
int
>
fused_ops
({
// first lstm
13
,
15
,
16
,
// second lstm
23
,
25
,
26
});
pattern
->
NewNode
([
&
](
Node
*
x
)
{
return
fused_ops
.
count
(
x
->
id
());
},
"any_node"
);
std
::
unordered_set
<
Node
*>
marked_nodes
;
void
BuildPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
with_fc_bias
)
{
PDNode
*
x
=
pattern
->
NewNode
(
name_scope
,
"x"
)
->
assert_is_op_input
(
"mul"
)
->
assert_var_not_persistable
();
auto
*
fc_out
=
patterns
::
FC
(
pattern
,
name_scope
,
x
,
with_fc_bias
);
fc_out
->
AsIntermediate
();
// fc_out is a tmp var, will be removed after fuse.
patterns
::
LSTM
(
pattern
,
name_scope
,
fc_out
);
// LOG(INFO) << "\n" << pattern->DotString();
}
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
,
Graph
*
g
)
{
bool
with_fc_bias
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
RetrieveNode
(
"any_node"
));
BuildPattern
(
pattern
,
name_scope
,
with_fc_bias
);
marked_nodes
.
insert
(
id
);
};
gpd
(
graph
.
get
(),
handler
);
// Create New OpDesc
// Create New OpDesc
auto
lstm_creator
=
[
&
](
int
lstm
,
int
input
,
int
weight_x
,
int
weight_h
,
auto
lstm_creator
=
[
&
](
int
lstm
,
int
input
,
int
weight_x
,
int
weight_h
,
int
bias
,
int
hidden
,
int
cell
,
int
xx
)
{
int
bias
,
int
hidden
,
int
cell
,
int
xx
,
int
fc_bias
)
{
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE
(
input
);
GET_NODE
(
input
);
GET_NODE
(
weight_x
);
GET_NODE
(
weight_x
);
...
@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightH
,
weight_h
);
SET_IN
(
WeightH
,
weight_h
);
SET_IN
(
Bias
,
bias
);
SET_IN
(
Bias
,
bias
);
#undef GET_NODE
#undef SET_IN
#undef SET_IN
if
(
with_fc_bias
)
{
// Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE
(
scope
);
const
std
::
string
&
new_bias_var
=
name_scope
+
"_bias.new"
;
auto
*
bias_var
=
scope
->
Var
(
new_bias_var
);
PADDLE_ENFORCE
(
bias_var
);
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
lstm_bias_var
=
scope
->
FindVar
(
bias_n
->
Name
());
PADDLE_ENFORCE
(
lstm_bias_var
);
const
auto
&
lstm_bias_tensor
=
lstm_bias_var
->
Get
<
framework
::
LoDTensor
>
();
bias_tensor
->
Resize
(
lstm_bias_tensor
.
dims
());
GET_NODE
(
fc_bias
);
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias_n
->
Name
());
const
auto
&
fc_bias_tensor
=
fc_bias_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
data
=
bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
bias_tensor
->
numel
();
i
++
)
{
data
[
i
]
=
fc_bias_tensor
.
data
<
float
>
()[
i
]
+
lstm_bias_tensor
.
data
<
float
>
()[
i
];
}
op_desc
.
SetInput
(
"Bias"
,
{
new_bias_var
});
}
VLOG
(
4
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
#undef GET_NODE
VLOG
(
4
)
<<
"cell: "
<<
cell_n
->
Name
();
VLOG
(
4
)
<<
"xx: "
<<
xx_n
->
Name
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
...
@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
op_desc
.
SetOutput
(
"BatchedGate"
,
{
"blstm_0.tmp_2"
});
op_desc
.
SetOutput
(
"BatchedGate"
,
{
"blstm_0.tmp_2"
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
"blstm_1.tmp_2"
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
"blstm_1.tmp_2"
});
op_desc
.
SetAttr
(
"is_reverse"
,
lstm_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"is_reverse"
,
lstm_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
false
);
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm_n
->
Op
()
->
GetAttr
(
"use_peepholes"
)
);
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
#define LINK_TO(a, b) \
#define LINK_TO(a, b) \
...
@@ -89,33 +110,71 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -89,33 +110,71 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
LINK_TO
(
op
,
hidden_n
);
LINK_TO
(
op
,
hidden_n
);
#undef LINK_TO
#undef LINK_TO
return
op
;
return
op
;
};
};
lstm_creator
(
16
,
12
,
14
,
18
,
17
,
22
,
21
,
19
);
int
fusion_count
{
0
};
lstm_creator
(
26
,
12
,
24
,
28
,
27
,
32
,
31
,
29
);
// remove all the nodes
auto
fc_no_bias_handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
for
(
auto
*
node
:
marked_nodes
)
{
#define GET_NODE(name__) \
graph
->
RemoveNode
(
const_cast
<
Node
*>
(
node
));
std::string name__##key = name_scope + "/" + #name__; \
}
auto* name__##n = pattern->RetrieveNode(name__##key); \
PADDLE_ENFORCE(name__##n); \
PADDLE_ENFORCE(subgraph.count(name__##n)); \
Node* name__##_n = subgraph.at(name__##n); \
int name__ __attribute__((unused)) = name__##_n->id();
for
(
auto
*
node
:
graph
->
Nodes
())
{
GET_NODE
(
x
);
for
(
auto
it
=
node
->
inputs
.
begin
();
it
!=
node
->
inputs
.
end
();)
{
GET_NODE
(
w
);
if
(
marked_nodes
.
count
(
*
it
))
{
GET_NODE
(
mul
);
it
=
const_cast
<
Node
*>
(
node
)
->
inputs
.
erase
(
it
);
GET_NODE
(
fc_out
);
}
else
GET_NODE
(
Weight
);
it
++
;
GET_NODE
(
lstm
);
}
GET_NODE
(
Bias
);
for
(
auto
it
=
node
->
outputs
.
begin
();
it
!=
node
->
outputs
.
end
();)
{
GET_NODE
(
Hidden
);
if
(
marked_nodes
.
count
(
*
it
))
{
GET_NODE
(
Cell
);
it
=
const_cast
<
Node
*>
(
node
)
->
outputs
.
erase
(
it
);
}
else
if
(
with_fc_bias
)
{
it
++
;
GET_NODE
(
fc_bias
);
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
fc_bias
);
}
else
{
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
-
1
);
}
}
}
#undef GET_NODE
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
mul_n
,
lstm_n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
fc_no_bias_handler
);
return
fusion_count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MulLstmFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
BuildFusion
(
graph
.
get
(),
name_scope_
,
param_scope
(),
false
/*with_fc_bias*/
);
AddStatis
(
fusion_count
);
return
graph
;
}
std
::
unique_ptr
<
ir
::
Graph
>
FCLstmFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
BuildFusion
(
graph
.
get
(),
name_scope_
,
param_scope
(),
true
/*with_fc_bias*/
);
AddStatis
(
fusion_count
);
return
graph
;
return
graph
;
}
}
...
@@ -123,4 +182,5 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -123,4 +182,5 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
mul_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
MulLstmFusePass
);
REGISTER_PASS
(
fc_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
FCLstmFusePass
);
REGISTER_PASS
(
fc_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
FCLstmFusePass
);
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
浏览文件 @
597b7305
...
@@ -12,20 +12,34 @@
...
@@ -12,20 +12,34 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
FCLstmFusePass
:
public
Pass
{
// The MulLstmFusePass and MulLstmFusePass will fuse to the same FusionLstm op.
// Just FC without bias
class
FCLstmFusePass
:
public
FusePassBase
{
public:
public:
virtual
~
FCLstmFusePass
()
{}
virtual
~
FCLstmFusePass
()
{}
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"fc_lstm_fuse"
};
};
class
MulLstmFusePass
:
public
FusePassBase
{
public:
virtual
~
MulLstmFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"fc_nobias_lstm_fuse"
};
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
597b7305
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
...
@@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void
GraphPatternDetector
::
operator
()(
Graph
*
graph
,
void
GraphPatternDetector
::
operator
()(
Graph
*
graph
,
GraphPatternDetector
::
handle_t
handler
)
{
GraphPatternDetector
::
handle_t
handler
)
{
if
(
!
MarkPDNodesInGraph
(
*
graph
))
return
;
if
(
!
MarkPDNodesInGraph
(
*
graph
))
{
LOG
(
INFO
)
<<
"Mark failed"
;
return
;
}
auto
subgraphs
=
DetectPatterns
();
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
...
@@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph,
...
@@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph,
}
}
bool
GraphPatternDetector
::
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
)
{
bool
GraphPatternDetector
::
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
)
{
VLOG
(
4
)
<<
"mark pdnodes in graph"
;
VLOG
(
3
)
<<
"mark pdnodes in graph"
;
if
(
graph
.
Nodes
().
empty
())
return
false
;
if
(
graph
.
Nodes
().
empty
())
return
false
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
...
@@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
...
@@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
}
}
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
return
!
pdnodes2nodes_
.
empty
();
return
!
pdnodes2nodes_
.
empty
();
}
}
...
@@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
...
@@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
assert_is_op_input
(
op_type
);
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
IsNthInput
(
x
,
op
,
argument
,
nth
))
return
true
;
if
(
op
->
IsOp
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
IsNthInput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
}
return
false
;
return
false
;
});
});
...
@@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
...
@@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
assert_is_var
();
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
IsNthOutput
(
x
,
op
,
argument
,
nth
))
return
true
;
if
(
op
->
IsOp
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
IsNthOutput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
}
return
false
;
return
false
;
});
});
...
@@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
...
@@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
});
});
return
this
;
return
this
;
}
}
PDNode
*
PDNode
::
assert_is_op_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
)
{
assert_is_var
();
assert_is_op_nth_output
(
op_type
,
argument
,
0
);
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
...
@@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
...
@@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
});
});
return
this
;
return
this
;
}
}
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
)
{
assert_is_var
();
assert_is_op_nth_input
(
op_type
,
argument
,
0
);
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
PDNode
*
PDNode
::
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
inputs
.
size
()
==
n
;
});
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
inputs
.
size
()
==
n
;
});
...
@@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
...
@@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
return
this
;
return
this
;
}
}
bool
VarLinksToOp
(
Node
*
node
,
const
std
::
string
&
op_type
)
{
for
(
auto
*
out
:
node
->
outputs
)
{
if
(
out
->
IsOp
()
&&
out
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
}
bool
IsNthInput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
Op
()
->
Input
(
argument
).
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
}
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
Op
()
->
Output
(
argument
).
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Output
(
argument
)[
nth
];
}
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
graph
->
RemoveNode
(
const_cast
<
Node
*>
(
node
));
}
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
it
=
node
->
inputs
.
begin
();
it
!=
node
->
inputs
.
end
();)
{
if
(
nodes
.
count
(
*
it
))
{
it
=
const_cast
<
Node
*>
(
node
)
->
inputs
.
erase
(
it
);
}
else
it
++
;
}
for
(
auto
it
=
node
->
outputs
.
begin
();
it
!=
node
->
outputs
.
end
();)
{
if
(
nodes
.
count
(
*
it
))
{
it
=
const_cast
<
Node
*>
(
node
)
->
outputs
.
erase
(
it
);
}
else
it
++
;
}
}
}
bool
VarLinksFromOp
(
Node
*
node
,
const
std
::
string
&
op_type
)
{
for
(
auto
*
out
:
node
->
inputs
)
{
if
(
out
->
IsOp
()
&&
out
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
}
PDNode
*
patterns
::
FC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
,
bool
with_bias
)
{
// Create Operators
PDNode
*
elementwise_add_op
{
nullptr
};
auto
*
mul_op
=
pattern
->
NewNode
(
name_scope
,
"mul"
)
->
assert_is_op
(
"mul"
);
if
(
with_bias
)
{
elementwise_add_op
=
pattern
->
NewNode
(
name_scope
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
);
}
// Create variables
// w
auto
*
mul_weight_var
=
pattern
->
NewNode
(
name_scope
,
"w"
)
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_nth_input
(
"mul"
,
"Y"
,
0
);
PDNode
*
mul_out_var
{
nullptr
};
if
(
with_bias
)
{
// intermediate variable, will be removed in the IR after fuse.
mul_out_var
=
pattern
->
NewNode
(
name_scope
,
"mul_out"
)
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"mul"
)
->
assert_is_op_input
(
"elementwise_add"
);
}
PDNode
*
bias
{
nullptr
},
*
fc_out
{
nullptr
};
if
(
with_bias
)
{
// bias
bias
=
pattern
->
NewNode
(
name_scope
,
"fc_bias"
)
->
assert_is_op_input
(
"elementwise_add"
)
->
AsInput
();
// output
fc_out
=
pattern
->
NewNode
(
name_scope
,
"fc_out"
)
->
AsOutput
()
->
assert_is_op_output
(
"elementwise_add"
);
}
else
{
fc_out
=
pattern
->
NewNode
(
name_scope
,
"fc_out"
)
->
AsOutput
()
->
assert_is_op_output
(
"mul"
);
}
if
(
with_bias
)
{
mul_op
->
LinksFrom
({
mul_weight_var
,
x
}).
LinksTo
({
mul_out_var
});
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
bias
}).
LinksTo
({
fc_out
});
}
else
{
mul_op
->
LinksFrom
({
mul_weight_var
,
x
}).
LinksTo
({
fc_out
});
}
return
fc_out
;
}
PDNode
*
patterns
::
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lstm"
,
"Input"
);
auto
*
lstm_op
=
pattern
->
NewNode
(
name_scope
,
"lstm"
)
->
assert_is_op
(
"lstm"
);
#define NEW_NODE(arg__, io__) \
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional
// TODO(Superjomn) upgrade the fuse framework to support optional.
// NEW_NODE(H0, input);
// NEW_NODE(C0, input);
NEW_NODE
(
Weight
,
input
);
NEW_NODE
(
Bias
,
input
);
NEW_NODE
(
Hidden
,
output
);
NEW_NODE
(
Cell
,
output
);
NEW_NODE
(
BatchGate
,
output
);
NEW_NODE
(
BatchCellPreAct
,
output
);
lstm_op
->
LinksFrom
({
x
,
Weight
,
Bias
});
lstm_op
->
LinksTo
({
Hidden
,
Cell
,
BatchGate
,
BatchCellPreAct
});
return
Hidden
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
597b7305
...
@@ -95,7 +95,11 @@ struct PDNode {
...
@@ -95,7 +95,11 @@ struct PDNode {
PDNode
*
assert_var_not_persistable
();
PDNode
*
assert_var_not_persistable
();
PDNode
*
assert_is_persistable_var
();
PDNode
*
assert_is_persistable_var
();
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
...
@@ -167,6 +171,9 @@ class PDPattern {
...
@@ -167,6 +171,9 @@ class PDPattern {
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
const
std
::
string
&
prefix
,
const
std
::
string
&
name
)
{
return
NewNode
(
prefix
+
"/"
+
name
);
}
PDNode
*
RetrieveNode
(
const
std
::
string
&
id
)
const
;
PDNode
*
RetrieveNode
(
const
std
::
string
&
id
)
const
;
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
...
@@ -257,64 +264,36 @@ class GraphPatternDetector {
...
@@ -257,64 +264,36 @@ class GraphPatternDetector {
// some helper methods.
// some helper methods.
// Op's input.
// Tell if a var links to an Op
static
bool
VarLinksToOp
(
Node
*
node
,
const
std
::
string
&
op_type
)
{
bool
VarLinksToOp
(
Node
*
node
,
const
std
::
string
&
op_type
);
for
(
auto
*
out
:
node
->
outputs
)
{
if
(
out
->
IsOp
()
&&
out
->
Op
()
->
Type
()
==
op_type
)
{
// Tell if an op links to a var
return
true
;
bool
VarLinksFromOp
(
Node
*
node
,
const
std
::
string
&
op_type
);
}
}
return
false
;
}
// Op's output.
static
bool
VarLinksFromOp
(
Node
*
node
,
const
std
::
string
&
op_type
)
{
for
(
auto
*
out
:
node
->
inputs
)
{
if
(
out
->
IsOp
()
&&
out
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
}
// Check whether a var node is a op node's nth input.
// Check whether a var node is a op node's nth input.
static
bool
IsNthInput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
bool
IsNthInput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
);
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
inputs
.
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
}
static
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
inputs
.
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Output
(
argument
)[
nth
];
}
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
graph
->
RemoveNode
(
const_cast
<
Node
*>
(
node
));
}
for
(
auto
*
node
:
graph
->
Nodes
())
{
// Tell whether a var node is a op node's nth output.
for
(
auto
it
=
node
->
inputs
.
begin
();
it
!=
node
->
inputs
.
end
();)
{
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
);
if
(
nodes
.
count
(
*
it
))
{
it
=
const_cast
<
Node
*>
(
node
)
->
inputs
.
erase
(
it
);
// Graph safely remove some nodes, will automatically clean up the edges.
}
else
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
it
++
;
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
}
for
(
auto
it
=
node
->
outputs
.
begin
();
it
!=
node
->
outputs
.
end
();)
{
// Some pre-defined patterns those can be reused in multiple passes.
if
(
nodes
.
count
(
*
it
))
{
namespace
patterns
{
it
=
const_cast
<
Node
*>
(
node
)
->
outputs
.
erase
(
it
);
}
else
// FC with bias
it
++
;
// op: mul + elementwise_add
}
// named nodes:
}
// mul, elementwise_add
}
// w, mul_out, bias, fc_out
PDNode
*
FC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
,
bool
with_bias
);
PDNode
*
LSTM
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PDNode
*
x
);
}
// namespace patterns
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/graph_viz_pass.h
浏览文件 @
597b7305
...
@@ -42,6 +42,13 @@ class GraphVizPass : public Pass {
...
@@ -42,6 +42,13 @@ class GraphVizPass : public Pass {
marked_nodes_t
ConsumeMarkedNodes
(
Graph
*
graph
)
const
;
marked_nodes_t
ConsumeMarkedNodes
(
Graph
*
graph
)
const
;
};
};
static
GraphVizPass
::
marked_nodes_t
&
GetMarkedNodes
(
Graph
*
graph
)
{
if
(
!
graph
->
Has
(
kGraphvizMarkedNodeAttr
))
{
graph
->
Set
(
kGraphvizMarkedNodeAttr
,
new
GraphVizPass
::
marked_nodes_t
);
}
return
graph
->
Get
<
GraphVizPass
::
marked_nodes_t
>
(
kGraphvizMarkedNodeAttr
);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
597b7305
...
@@ -109,6 +109,7 @@ void Analyzer::Run(Argument* argument) {
...
@@ -109,6 +109,7 @@ void Analyzer::Run(Argument* argument) {
"infer_clean_graph_pass"
,
"graph_viz_pass"
,
//
"infer_clean_graph_pass"
,
"graph_viz_pass"
,
//
"attention_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"attention_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"mul_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"seq_concat_fc_fuse_pass"
,
"graph_viz_pass"
,
//
"seq_concat_fc_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_fuse_pass"
,
"graph_viz_pass"
//
"fc_fuse_pass"
,
"graph_viz_pass"
//
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
597b7305
...
@@ -329,6 +329,7 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -329,6 +329,7 @@ void TestDituRNNPrediction(const std::string &model_path,
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc"
),
1
);
EXPECT_EQ
(
fuse_statis
.
at
(
"fc"
),
1
);
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_nobias_lstm_fuse"
),
1
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录