Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
902f19b4
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
902f19b4
编写于
8月 29, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
8月 29, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fea/fuse attention lstm simplify.with fusion lstm.with sequnce expand (#13006)
上级
55f240ba
变更
40
显示空白变更内容
内联
并排
Showing
40 changed file
with
1507 addition
and
211 deletion
+1507
-211
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+7
-5
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+273
-0
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
+14
-7
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+6
-8
paddle/fluid/framework/ir/fc_fuse_pass.h
paddle/fluid/framework/ir/fc_fuse_pass.h
+1
-1
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+126
-0
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+33
-0
paddle/fluid/framework/ir/fuse_pass_base.h
paddle/fluid/framework/ir/fuse_pass_base.h
+44
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+15
-4
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+1
-1
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+61
-16
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+84
-15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
+5
-5
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+54
-28
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+9
-0
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+15
-6
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+256
-0
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h
+33
-0
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+16
-8
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+13
-0
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+24
-47
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+41
-0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+5
-0
paddle/fluid/inference/analysis/dot.h
paddle/fluid/inference/analysis/dot.h
+12
-6
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
+44
-0
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+42
-12
paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
+6
-1
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+8
-4
paddle/fluid/inference/analysis/ir_pass_manager.h
paddle/fluid/inference/analysis/ir_pass_manager.h
+5
-3
paddle/fluid/inference/analysis/pass_manager.cc
paddle/fluid/inference/analysis/pass_manager.cc
+2
-2
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+3
-2
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+165
-0
paddle/fluid/inference/api/helper.cc
paddle/fluid/inference/api/helper.cc
+44
-0
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+9
-20
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+1
-0
paddle/fluid/inference/io.cc
paddle/fluid/inference/io.cc
+16
-0
paddle/fluid/inference/io.h
paddle/fluid/inference/io.h
+5
-0
paddle/fluid/operators/attention_lstm_op.cc
paddle/fluid/operators/attention_lstm_op.cc
+1
-1
paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
+8
-3
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+0
-6
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
902f19b4
...
@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
...
@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits
)
cc_library
(
graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits
)
cc_library
(
fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter
)
cc_library
(
fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector
)
cc_library
(
attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector
)
cc_library
(
infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
)
cc_library
(
infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
)
cc_library
(
fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector
)
cc_library
(
seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass
)
cc_test
(
graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass
)
cc_test
(
test_graph_pattern_detect
er SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecte
r
)
cc_test
(
test_graph_pattern_detect
or SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detecto
r
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detect
e
r graph pass graph_traits framework_proto
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detect
o
r graph pass graph_traits framework_proto
)
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Variable names wired into the fused attention_lstm op. Each member's
// identifier is used as the op slot name and its value as the variable name
// bound to that slot (see OP_SET_IN / OP_SET_OUT in FindWhileOp).
struct Param {
  // Slots registered with OP_SET_IN.
  std::string X{"concat_0.tmp_0"};
  std::string C0{"cell_init"};
  std::string H0{"hidden_init"};
  std::string AttentionWeight{"attention_fc.w_0"};
  std::string AttentionBias{"attention_fc.b_0"};
  std::string AttentionScalar{"attention_output.w_0"};
  std::string AttentionScalarBias{"attention_output.b_0"};
  std::string LSTMWeight{"attention_w.new"};
  std::string LSTMBias{"attention_b.new"};

  // Slots registered with OP_SET_OUT.
  std::string Hidden{"array_to_lod_tensor_0.tmp_0"};
  std::string Cell{"at.cell.new"};
  std::string AttentionedX{"at.x.new"};
  std::string AttentionFCOut{"at.fc.new"};
  std::string LSTMX{"at.lstmx.new"};
  std::string LSTMOUT{"at.lstmout.new"};
};
// Forward declaration; the definition appears later in this file.
// FindWhileOp calls it after creating the fused op node, passing the graph
// and the Param holding the variable names to create / look up.
void
PrepareParameters
(
Graph
*
graph
,
const
Param
&
param
);
// Replaces a hard-coded set of graph nodes with a single attention_lstm op.
//
// Steps:
//   1. Mark every node whose id is in `fused_external_ops` in the graph's
//      kGraphvizMarkedNodeAttr set (created here if absent).
//   2. Build an attention_lstm OpDesc whose inputs/outputs are the names held
//      in a default-constructed Param.
//   3. Create the op node, prepare its parameters, and link it to the
//      existing nodes with ids 34 (X), 6 (cell_init), 8 (hidden_init) and
//      81 (LSTMOUT).
//   4. Remove all marked nodes through GraphSafeRemoveNodes.
//
// NOTE(review): all node ids are hard-coded, so this presumably only works
// for one specific model graph -- confirm with the pass author.
//
// graph: the graph to rewrite in place; it must contain the four linked node
//        ids above (enforced) and a parameter scope (enforced by
//        PrepareParameters).
void FindWhileOp(Graph* graph) {
  GraphPatternDetector gpd;
  // Ids of the nodes that the fused op replaces.
  std::unordered_set<int> fused_external_ops(
      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54,
       48, 57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 39, 40, 51});

  gpd.mutable_pattern()->NewNode(
      [&](Node* n) { return fused_external_ops.count(n->id()) != 0; },
      "while");

  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
  }
  auto& marked_nodes =
      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

  // Records every matched node so that it can be removed at the end.
  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* g) {
    auto* while_pat_node = gpd.pattern().RetriveNode("while");
    auto* while_node = subgraph.at(while_pat_node);
    marked_nodes.insert(while_node);
  };
  gpd(graph, handle);

  Param param;
  // Add AttentionLSTM node
  OpDesc op_desc;
  op_desc.SetType("attention_lstm");

#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
  OP_SET_IN(X);
  OP_SET_IN(C0);
  OP_SET_IN(H0);
  OP_SET_IN(AttentionWeight);
  OP_SET_IN(AttentionBias);
  OP_SET_IN(AttentionScalar);
  OP_SET_IN(AttentionScalarBias);
  OP_SET_IN(LSTMWeight);
  OP_SET_IN(LSTMBias);

  OP_SET_OUT(Hidden);
  OP_SET_OUT(Cell);
  OP_SET_OUT(AttentionedX);
  OP_SET_OUT(AttentionFCOut);
  OP_SET_OUT(LSTMX);
  OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT

  auto* X = graph->RetriveNode(34);
  auto* LSTMOUT = graph->RetriveNode(81);
  auto* cell_init = graph->RetriveNode(6);
  auto* hidden_init = graph->RetriveNode(8);
  // RetriveNode yields nullptr for an unknown id; fail with a clear message
  // before the graph is modified instead of dereferencing null below.
  PADDLE_ENFORCE_NOT_NULL(X, "graph has no node with id %d", 34);
  PADDLE_ENFORCE_NOT_NULL(LSTMOUT, "graph has no node with id %d", 81);
  PADDLE_ENFORCE_NOT_NULL(cell_init, "graph has no node with id %d", 6);
  PADDLE_ENFORCE_NOT_NULL(hidden_init, "graph has no node with id %d", 8);

#define LINK_TO(node0, node1)      \
  node0->outputs.push_back(node1); \
  node1->inputs.push_back(node0);

  auto* lstm_op = graph->CreateOpNode(&op_desc);
  PrepareParameters(graph, param);

  LINK_TO(X, lstm_op);
  LINK_TO(cell_init, lstm_op);
  LINK_TO(hidden_init, lstm_op);
  LINK_TO(lstm_op, LSTMOUT);
#undef LINK_TO

  GraphSafeRemoveNodes(graph, marked_nodes);
}
// Null-pointer check helpers: CHECK_Pn(x0, ..) applies
// PADDLE_ENFORCE_NOT_NULL to each of its n arguments, in order, by chaining
// to CHECK_P(n-1). Every expansion already ends with a semicolon.
#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
CHECK_P1(x0); \
CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
CHECK_P2(x0, x1); \
CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
CHECK_P3(x0, x1, x2); \
CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
CHECK_P4(x0, x1, x2, x3); \
CHECK_P1(x4);
// Forward declaration; the definition appears later in this file.
// Takes two weight tensors per gate (forget, input, output, cell) and writes
// the combined weight into `out`.
void
PrepareLSTMWeight
(
const
LoDTensor
&
W_forget_w0
,
const
LoDTensor
&
W_forget_w1
,
const
LoDTensor
&
W_input_w0
,
const
LoDTensor
&
W_input_w1
,
const
LoDTensor
&
W_output_w0
,
const
LoDTensor
&
W_output_w1
,
const
LoDTensor
&
W_cell_w0
,
const
LoDTensor
&
W_cell_w1
,
LoDTensor
*
out
);
// Forward declaration; the definition appears later in this file.
// Takes one bias tensor per gate (forget, input, output, cell) and writes the
// combined bias into `out`.
void
PrepareLSTMBias
(
const
LoDTensor
&
B_forget
,
const
LoDTensor
&
B_input
,
const
LoDTensor
&
B_output
,
const
LoDTensor
&
B_cell
,
LoDTensor
*
out
);
// Creates the variables needed by the fused attention_lstm op in the
// parameter scope attached to `graph`, reshapes the two attention biases to
// 2-D, and fills the combined LSTM weight / bias from the per-gate
// parameters.
//
// graph: must carry a Scope* under kParamScopeAttr (enforced).
// param: names of the variables to create or look up.
//
// Fails (via the PADDLE_ENFORCE family) when the scope attribute or any
// required variable is missing, or when either attention bias is not 1-D.
void PrepareParameters(Graph* graph, const Param& param) {
  // Check parameters
  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
  auto* scope = graph->Get<Scope*>(kParamScopeAttr);

  // Create new parameters.
  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
  scope->Var(param.Cell)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();

// For gate `name__`: looks up the variables "<name__>.w_0", "<name__>.w_1"
// and "<name__>.b_0", enforces that all three exist, logs their shapes and
// binds their tensors to W_<name__>_w0_t / _w1_t / _b0_t.
#define GATE_W(name__)                                               \
  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
  VLOG(4) << #name__ "_w0"                                           \
          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_w1"                                           \
          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_b0"                                           \
          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();       \
  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();       \
  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

  GATE_W(forget);
  GATE_W(input);
  GATE_W(output);
  GATE_W(c);
#undef GATE_W

  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
  auto* attention_output_w = scope->FindVar("attention_output.w_0");
  auto* attention_output_b = scope->FindVar("attention_output.b_0");
  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
           attention_output_b);

  auto* lstm_weight = scope->Var(param.LSTMWeight);
  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
  auto* lstm_bias = scope->Var(param.LSTMBias);
  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

  // The checks above cover the literal names only; the lookups below go
  // through `param`, so verify them as well before dereferencing.
  auto* attention_bias_var = scope->FindVar(param.AttentionBias);
  auto* attention_scalar_bias_var = scope->FindVar(param.AttentionScalarBias);
  CHECK_P2(attention_bias_var, attention_scalar_bias_var);

  // reshape attention_bias: both biases must be 1-D and become [1, N].
  auto* attention_bias_t = attention_bias_var->GetMutable<LoDTensor>();
  auto* attention_scalar_bias_t =
      attention_scalar_bias_var->GetMutable<LoDTensor>();
  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
  PADDLE_ENFORCE_EQ(attention_scalar_bias_t->dims().size(), 1);
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
  attention_scalar_bias_t->Resize(
      make_ddim({1, attention_scalar_bias_t->dims()[0]}));

  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
                    lstm_weight_t);
  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
                  lstm_bias_t);
}
// Prepare parameters
void
PrepareLSTMWeight
(
const
LoDTensor
&
W_forget_w0
,
const
LoDTensor
&
W_forget_w1
,
const
LoDTensor
&
W_input_w0
,
const
LoDTensor
&
W_input_w1
,
const
LoDTensor
&
W_output_w0
,
const
LoDTensor
&
W_output_w1
,
const
LoDTensor
&
W_cell_w0
,
const
LoDTensor
&
W_cell_w1
,
LoDTensor
*
out
)
{
int
D
=
W_forget_w0
.
dims
()[
0
];
int
M
=
W_forget_w1
.
dims
()[
0
];
out
->
Resize
(
make_ddim
({
D
+
M
,
4
*
D
}));
VLOG
(
3
)
<<
"LSTMWeight resized to "
<<
out
->
dims
();
float
*
out_data
=
out
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
std
::
array
<
const
float
*
,
4
>
tensors
(
{
W_forget_w0
.
data
<
float
>
(),
W_input_w0
.
data
<
float
>
(),
W_output_w0
.
data
<
float
>
(),
W_cell_w0
.
data
<
float
>
()});
std
::
array
<
const
float
*
,
4
>
tensors1
(
{
W_forget_w1
.
data
<
float
>
(),
W_input_w1
.
data
<
float
>
(),
W_output_w1
.
data
<
float
>
(),
W_cell_w1
.
data
<
float
>
()});
for
(
int
row
=
0
;
row
<
D
;
row
++
)
{
for
(
int
col
=
0
;
col
<
4
;
col
++
)
{
float
*
dst
=
out_data
+
4
*
D
*
row
+
D
*
col
;
const
float
*
src
=
tensors
[
col
]
+
D
*
row
;
memcpy
(
dst
,
src
,
D
*
sizeof
(
float
));
}
}
for
(
int
row
=
0
;
row
<
M
;
row
++
)
{
for
(
int
col
=
0
;
col
<
4
;
col
++
)
{
float
*
dst
=
out_data
+
4
*
D
*
(
D
+
row
)
+
D
*
col
;
const
float
*
src
=
tensors1
[
col
]
+
D
*
row
;
memcpy
(
dst
,
src
,
D
*
sizeof
(
float
));
}
}
}
void
PrepareLSTMBias
(
const
LoDTensor
&
B_forget
,
const
LoDTensor
&
B_input
,
const
LoDTensor
&
B_output
,
const
LoDTensor
&
B_cell
,
LoDTensor
*
out
)
{
std
::
array
<
const
float
*
,
4
>
tensors
(
{
B_forget
.
data
<
float
>
(),
B_input
.
data
<
float
>
(),
B_output
.
data
<
float
>
(),
B_cell
.
data
<
float
>
()});
PADDLE_ENFORCE_EQ
(
B_forget
.
dims
().
size
(),
1
);
int
D
=
B_forget
.
dims
()[
0
];
out
->
Resize
(
make_ddim
({
1
,
4
*
D
}));
auto
*
out_data
=
out
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
memcpy
(
out_data
+
D
*
i
,
tensors
[
i
],
D
*
sizeof
(
float
));
}
}
// Parameters
// Entry point of the pass: runs FindWhileOp on the graph in place and hands
// the same graph object back to the caller.
//
// graph: the graph to rewrite; ownership is passed in and returned.
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  // The two PDPattern locals that used to be declared here were never read,
  // so they have been dropped.
  FindWhileOp(graph.get());
  return graph;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Registration of AttentionLSTMFusePass under the identifier
// attention_lstm_fuse_pass. NOTE(review): presumably this is the name used
// to look the pass up in the pass registry -- confirm against REGISTER_PASS.
REGISTER_PASS
(
attention_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
AttentionLSTMFusePass
);
paddle/fluid/
inference/analysis/dot.cc
→
paddle/fluid/
framework/ir/attention_lstm_fuse_pass.h
浏览文件 @
902f19b4
...
@@ -12,12 +12,19 @@
...
@@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h"
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
framework
{
namespace
analysis
{
namespace
ir
{
size_t
Dot
::
counter
=
0
;
}
// namespace analysis
class
AttentionLSTMFusePass
:
public
FusePassBase
{
}
// namespace inference
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
902f19b4
...
@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
...
@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
},
},
"elementwise_add_out"
);
"elementwise_add_out"
);
pattern
->
AddEdge
(
mul_parameter_var
,
mul_op
);
mul_op
->
LinksFrom
({
mul_parameter_var
,
mul_tmp_input_var
})
pattern
->
AddEdge
(
mul_tmp_input_var
,
mul_op
);
.
LinksTo
({
mul_out_var
});
pattern
->
AddEdge
(
mul_op
,
mul_out_var
);
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
pattern
->
AddEdge
(
mul_out_var
,
elementwise_add_op
);
.
LinksTo
({
elementwise_add_out_var
});
pattern
->
AddEdge
(
elementwise_add_tmp_var
,
elementwise_add_op
);
pattern
->
AddEdge
(
elementwise_add_op
,
elementwise_add_out_var
);
}
}
// Replace the node `from` in the links to `to`
// Replace the node `from` in the links to `to`
...
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std
::
unordered_set
<
Node
*>
nodes2delete
;
std
::
unordered_set
<
Node
*>
nodes2delete
;
GraphPatternDetect
e
r
gpd
;
GraphPatternDetect
o
r
gpd
;
BuildFCPattern
(
gpd
.
mutable_pattern
());
BuildFCPattern
(
gpd
.
mutable_pattern
());
#define GET_NODE(id) \
#define GET_NODE(id) \
...
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto
handler
=
[
&
](
const
GraphPatternDetect
e
r
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetect
o
r
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle FC fuse"
;
VLOG
(
4
)
<<
"handle FC fuse"
;
// Currently, there is no FC op available, so I will just simulate the
// Currently, there is no FC op available, so I will just simulate the
...
...
paddle/fluid/framework/ir/fc_fuse_pass.h
浏览文件 @
902f19b4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detect
e
r.h"
#include "paddle/fluid/framework/ir/graph_pattern_detect
o
r.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Replaces two hard-coded groups of nodes with fusion_lstm op nodes.
//
// Steps:
//   1. Collect every node whose id is in `fused_ops` into `marked_nodes`.
//   2. For each of the two LSTMs, build a fusion_lstm OpDesc from hard-coded
//      node ids, create the op node and link it to its inputs and to the
//      hidden output node.
//   3. Remove the marked nodes from the graph and erase them from the
//      inputs/outputs of every remaining node.
//
// NOTE(review): all node ids are hard-coded, so this presumably only works
// for one specific model graph -- confirm with the pass author.
//
// graph: the graph to rewrite; ownership is passed in and returned.
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  std::unordered_set<int> fused_ops({
      // first lstm
      13, 15, 16,
      // second lstm
      23, 25, 26});

  pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
                   "any_node");

  std::unordered_set<Node*> marked_nodes;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
    marked_nodes.insert(id);
  };
  gpd(graph.get(), handler);

  // Create New OpDesc
  // Every argument is a node id in `graph`; returns the created op node.
  auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
                          int bias, int hidden, int cell, int xx) {
// RetriveNode yields nullptr for an unknown id, so enforce that each node
// exists before it is dereferenced below.
#define GET_NODE(x)                    \
  auto* x##_n = graph->RetriveNode(x); \
  PADDLE_ENFORCE_NOT_NULL(x##_n, "graph has no node %s with id %d", #x, x);
    GET_NODE(input);
    GET_NODE(weight_x);
    GET_NODE(weight_h);
    GET_NODE(bias);
    GET_NODE(hidden);
    GET_NODE(cell);
    GET_NODE(xx);
    GET_NODE(lstm);

    OpDesc op_desc;
    op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
    SET_IN(X, input);
    SET_IN(WeightX, weight_x);
    SET_IN(WeightH, weight_h);
    SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN

    LOG(INFO) << "hidden_n: " << hidden_n->Name();
    LOG(INFO) << "cell: " << cell_n->Name();
    LOG(INFO) << "xx: " << xx_n->Name();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden_n->Name()});
    op_desc.SetOutput("Cell", {cell_n->Name()});
    op_desc.SetOutput("XX", {xx_n->Name()});
    // NOTE(review): both fused ops are given the same BatchedGate and
    // BatchCellPreAct variable names -- confirm this sharing is intended.
    op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
    op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
    op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", false);
    auto* op = graph->CreateOpNode(&op_desc);

#define LINK_TO(a, b)      \
  a->outputs.push_back(b); \
  b->inputs.push_back(a);
    LINK_TO(input_n, op);
    LINK_TO(weight_x_n, op);
    LINK_TO(weight_h_n, op);
    LINK_TO(bias_n, op);
    LINK_TO(op, hidden_n);
#undef LINK_TO
    return op;
  };

  lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
  lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);

  // remove all the nodes
  for (auto* node : marked_nodes) {
    graph->RemoveNode(const_cast<Node*>(node));
  }

  // Drop every edge that still refers to a removed node. From here on the
  // entries of `marked_nodes` are only compared by address, never
  // dereferenced.
  for (auto* node : graph->Nodes()) {
    for (auto it = node->inputs.begin(); it != node->inputs.end();) {
      if (marked_nodes.count(*it)) {
        it = const_cast<Node*>(node)->inputs.erase(it);
      } else {
        it++;
      }
    }
    for (auto it = node->outputs.begin(); it != node->outputs.end();) {
      if (marked_nodes.count(*it)) {
        it = const_cast<Node*>(node)->outputs.erase(it);
      } else {
        it++;
      }
    }
  }

  return graph;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Registration of FCLstmFusePass under the identifier fc_lstm_fuse_pass.
// NOTE(review): presumably this is the name used to look the pass up in the
// pass registry -- confirm against REGISTER_PASS.
REGISTER_PASS
(
fc_lstm_fuse_pass
,
paddle
::
framework
::
ir
::
FCLstmFusePass
);
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Graph pass whose ApplyImpl (defined in fc_lstm_fuse_pass.cc) takes
// ownership of a graph and returns the graph to use afterwards.
class FCLstmFusePass : public Pass {
 public:
  virtual ~FCLstmFusePass() = default;

 protected:
  // Receives the graph by value (ownership transfer) and returns the
  // resulting graph.
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/fuse_pass_base.h
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Name of the graph attribute under which the parameter Scope* is stored;
// FusePassBase::param_scope() reads the attribute with this key.
static
const
char
kParamScopeAttr
[]
=
"param_scope"
;
// Base class for fuse passes that need the parameter scope attached to the
// graph they operate on.
//
// Usage: call Init() with the graph first, then param_scope() to fetch the
// Scope* stored in that graph under kParamScopeAttr.
class FusePassBase : public Pass {
 public:
  // Remembers the graph that param_scope() will query. Const because passes
  // are applied through const methods; graph_ is mutable for that reason.
  void Init(Graph* graph) const { graph_ = graph; }

  // Returns the Scope* stored in the graph under kParamScopeAttr.
  // Fails if Init() has not been called or the graph lacks the attribute.
  Scope* param_scope() const {
    PADDLE_ENFORCE_NOT_NULL(graph_,
                            "Init() must be called before param_scope()");
    PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
    return graph_->Get<framework::Scope*>(kParamScopeAttr);
  }

  virtual ~FusePassBase() {}

 protected:
  // Not owned. Starts as nullptr so that use before Init() is detected
  // instead of dereferencing an indeterminate pointer.
  mutable Graph* graph_{nullptr};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph.h
浏览文件 @
902f19b4
...
@@ -99,13 +99,13 @@ class Graph {
...
@@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc.
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
PADDLE_ENFORCE
(
var_desc
);
PADDLE_ENFORCE
(
var_desc
);
return
AddNode
(
new
ir
::
Node
(
var_desc
));
return
AddNode
(
new
ir
::
Node
(
var_desc
,
node_count_
++
));
}
}
// Create a normal runnable operator with OpDesc.
// Create a normal runnable operator with OpDesc.
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
PADDLE_ENFORCE
(
op_desc
);
PADDLE_ENFORCE
(
op_desc
);
return
AddNode
(
new
ir
::
Node
(
op_desc
));
return
AddNode
(
new
ir
::
Node
(
op_desc
,
node_count_
++
));
}
}
// Create a control dependency var that connects 2 operations. The
// Create a control dependency var that connects 2 operations. The
...
@@ -115,13 +115,14 @@ class Graph {
...
@@ -115,13 +115,14 @@ class Graph {
// TODO(panyx0718): control var name should be really unique.
// TODO(panyx0718): control var name should be really unique.
const
std
::
string
name
=
string
::
Sprintf
(
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
ir
::
Node
::
kControlDepVarName
,
node_set_
.
size
());
"%s@%llu"
,
ir
::
Node
::
kControlDepVarName
,
node_set_
.
size
());
return
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
return
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
,
node_count_
++
));
}
}
// A more free style way of creating a graph node. Mostly use for test
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
// or "copy" from another node. Avoid using it if possible.
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
return
AddNode
(
new
ir
::
Node
(
name
,
type
));
return
AddNode
(
new
ir
::
Node
(
name
,
type
,
node_count_
++
));
}
}
// Clear all node information of the graph and return the ownership of the
// Clear all node information of the graph and return the ownership of the
...
@@ -142,12 +143,20 @@ class Graph {
...
@@ -142,12 +143,20 @@ class Graph {
nodes_
.
erase
(
node
);
nodes_
.
erase
(
node
);
}
}
Node
*
RetriveNode
(
int
id
)
{
auto
it
=
id2node_
.
find
(
id
);
if
(
it
!=
id2node_
.
end
())
return
it
->
second
;
return
nullptr
;
}
private:
private:
// This method takes ownership of `node`.
// This method takes ownership of `node`.
ir
::
Node
*
AddNode
(
ir
::
Node
*
node
)
{
ir
::
Node
*
AddNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
==
node_set_
.
end
());
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
==
node_set_
.
end
());
nodes_
[
node
].
reset
(
node
);
nodes_
[
node
].
reset
(
node
);
node_set_
.
insert
(
node
);
node_set_
.
insert
(
node
);
PADDLE_ENFORCE
(
!
id2node_
.
count
(
node
->
id
()),
"duplicate id %d"
,
node
->
id
());
id2node_
[
node
->
id
()]
=
node
;
return
node
;
return
node
;
}
}
...
@@ -157,6 +166,8 @@ class Graph {
...
@@ -157,6 +166,8 @@ class Graph {
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
ir
::
Node
*
,
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
std
::
map
<
ir
::
Node
*
,
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
std
::
unordered_set
<
ir
::
Node
*>
node_set_
;
std
::
unordered_set
<
ir
::
Node
*>
node_set_
;
std
::
map
<
int
,
Node
*>
id2node_
;
int
node_count_
{
0
};
};
};
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
902f19b4
...
@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
...
@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
for
(
auto
&
var
:
n
->
inputs
)
{
for
(
auto
&
var
:
n
->
inputs
)
{
for
(
auto
&
adj_n
:
var
->
inputs
)
{
for
(
auto
&
adj_n
:
var
->
inputs
)
{
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
adj_list
[
n
].
insert
(
adj_n
);
VLOG
(
4
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
VLOG
(
4
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
)
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
adj_list
[
n
].
insert
(
adj_n
);
}
}
}
}
}
}
...
...
paddle/fluid/framework/ir/graph_pattern_detect
e
r.cc
→
paddle/fluid/framework/ir/graph_pattern_detect
o
r.cc
浏览文件 @
902f19b4
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detect
e
r.h"
#include "paddle/fluid/framework/ir/graph_pattern_detect
o
r.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
...
@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
name
);
name
);
}
}
nodes_
.
emplace_back
(
new
PDNode
(
std
::
move
(
teller
),
name
));
nodes_
.
emplace_back
(
new
PDNode
(
std
::
move
(
teller
),
this
,
name
));
auto
*
cur
=
nodes_
.
back
().
get
();
auto
*
cur
=
nodes_
.
back
().
get
();
node_map_
[
name
]
=
cur
;
node_map_
[
name
]
=
cur
;
return
cur
;
return
cur
;
...
@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
...
@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
edges_
.
emplace_back
(
a
,
b
);
edges_
.
emplace_back
(
a
,
b
);
}
}
void
GraphPatternDetect
e
r
::
operator
()(
Graph
*
graph
,
void
GraphPatternDetect
o
r
::
operator
()(
Graph
*
graph
,
GraphPatternDetect
e
r
::
handle_t
handler
)
{
GraphPatternDetect
o
r
::
handle_t
handler
)
{
if
(
!
MarkPDNodesInGraph
(
*
graph
))
return
;
if
(
!
MarkPDNodesInGraph
(
*
graph
))
return
;
auto
subgraphs
=
DetectPatterns
();
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
for
(
auto
&
g
:
subgraphs
)
{
LOG
(
INFO
)
<<
"optimizing #"
<<
id
++
<<
" subgraph"
;
handler
(
g
,
graph
);
handler
(
g
,
graph
);
}
}
}
}
bool
GraphPatternDetect
e
r
::
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
)
{
bool
GraphPatternDetect
o
r
::
MarkPDNodesInGraph
(
const
ir
::
Graph
&
graph
)
{
VLOG
(
4
)
<<
"mark pdnodes in graph"
;
VLOG
(
4
)
<<
"mark pdnodes in graph"
;
if
(
graph
.
Nodes
().
empty
())
return
false
;
if
(
graph
.
Nodes
().
empty
())
return
false
;
...
@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
...
@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
return
false
;
return
false
;
}
}
std
::
vector
<
GraphPatternDetect
e
r
::
subgraph_t
>
std
::
vector
<
GraphPatternDetect
o
r
::
subgraph_t
>
GraphPatternDetect
e
r
::
DetectPatterns
()
{
GraphPatternDetect
o
r
::
DetectPatterns
()
{
// Init empty subgraphs.
// Init empty subgraphs.
std
::
vector
<
GraphPatternDetect
e
r
::
subgraph_t
>
result
;
std
::
vector
<
GraphPatternDetect
o
r
::
subgraph_t
>
result
;
std
::
vector
<
HitGroup
>
init_groups
;
std
::
vector
<
HitGroup
>
init_groups
;
PADDLE_ENFORCE
(
!
pattern_
.
edges
().
empty
(),
"At least one edge is needed"
);
std
::
array
<
std
::
vector
<
HitGroup
>
,
2
>
bi_records
;
auto
*
first_pnode
=
pattern_
.
edges
().
front
().
first
;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto
*
first_pnode
=
pattern_
.
edges
().
empty
()
?
pattern
().
nodes
().
front
().
get
()
:
pattern_
.
edges
().
front
().
first
;
if
(
!
pdnodes2nodes_
.
count
(
first_pnode
))
return
result
;
if
(
!
pdnodes2nodes_
.
count
(
first_pnode
))
return
result
;
for
(
auto
*
node
:
pdnodes2nodes_
[
first_pnode
])
{
for
(
auto
*
node
:
pdnodes2nodes_
[
first_pnode
])
{
HitGroup
group
;
HitGroup
group
;
...
@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
...
@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
}
}
int
step
=
0
;
int
step
=
0
;
std
::
array
<
std
::
vector
<
HitGroup
>
,
2
>
bi_records
;
bi_records
[
0
]
=
std
::
move
(
init_groups
);
bi_records
[
0
]
=
std
::
move
(
init_groups
);
// Extend a PDNode to subgraphs by deducing the connection relations defined
// Extend a PDNode to subgraphs by deducing the connection relations defined
...
@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
...
@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
auto
&
pre_groups
=
bi_records
[
step
%
2
];
auto
&
pre_groups
=
bi_records
[
step
%
2
];
auto
&
cur_groups
=
bi_records
[
1
-
(
step
++
%
2
)];
auto
&
cur_groups
=
bi_records
[
1
-
(
step
++
%
2
)];
cur_groups
.
clear
();
cur_groups
.
clear
();
if
(
pre_groups
.
empty
())
break
;
// source -> target
// source -> target
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
...
@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
...
@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
}
}
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
GraphPatternDetect
e
r
::
subgraph_t
subgraph
;
GraphPatternDetect
o
r
::
subgraph_t
subgraph
;
for
(
auto
&
role
:
group
.
roles
)
{
for
(
auto
&
role
:
group
.
roles
)
{
subgraph
.
emplace
(
role
.
first
,
role
.
second
);
subgraph
.
emplace
(
role
.
first
,
role
.
second
);
}
}
...
@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
...
@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
return
result
;
return
result
;
}
}
void
GraphPatternDetect
e
r
::
UniquePatterns
(
void
GraphPatternDetect
o
r
::
UniquePatterns
(
std
::
vector
<
GraphPatternDetect
e
r
::
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
GraphPatternDetect
o
r
::
subgraph_t
>*
subgraphs
)
{
if
(
subgraphs
->
empty
())
return
;
if
(
subgraphs
->
empty
())
return
;
std
::
vector
<
GraphPatternDetect
e
r
::
subgraph_t
>
result
;
std
::
vector
<
GraphPatternDetect
o
r
::
subgraph_t
>
result
;
std
::
unordered_set
<
size_t
>
set
;
std
::
unordered_set
<
size_t
>
set
;
for
(
auto
&
g
:
*
subgraphs
)
{
for
(
auto
&
g
:
*
subgraphs
)
{
...
@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
...
@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
*
subgraphs
=
result
;
*
subgraphs
=
result
;
}
}
void
GraphPatternDetect
e
r
::
RemoveOverlappedMatch
(
void
GraphPatternDetect
o
r
::
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
subgraph_t
>
result
;
std
::
vector
<
subgraph_t
>
result
;
std
::
unordered_set
<
Node
*>
node_set
;
std
::
unordered_set
<
Node
*>
node_set
;
...
@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
...
@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
*
subgraphs
=
result
;
*
subgraphs
=
result
;
}
}
std
::
string
PDPattern
::
DotString
()
const
{
using
inference
::
analysis
::
Dot
;
Dot
dot
;
int
id
=
0
;
// Create Nodes
std
::
unordered_map
<
PDNode
*
,
std
::
string
>
node2dot
;
for
(
const
auto
&
node
:
nodes
())
{
std
::
string
node_id
=
"Node"
+
std
::
to_string
(
id
++
);
dot
.
AddNode
(
node_id
,
{},
node
->
name
());
node2dot
[
node
.
get
()]
=
node_id
;
}
// Create Edges
for
(
const
auto
&
edge
:
edges
())
{
if
(
!
node2dot
.
count
(
edge
.
first
)
||
!
node2dot
.
count
(
edge
.
second
))
{
LOG
(
ERROR
)
<<
"no node "
<<
edge
.
first
<<
" "
<<
edge
.
second
;
continue
;
}
auto
&
src
=
node2dot
.
at
(
edge
.
first
);
auto
&
trg
=
node2dot
.
at
(
edge
.
second
);
dot
.
AddEdge
(
src
,
trg
,
{});
}
return
dot
.
Build
();
}
// Add an edge this -> other to the owning pattern for every entry of
// `others`. Returns *this so that calls can be chained.
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
  for (size_t i = 0; i < others.size(); ++i) {
    pattern_->AddEdge(this, others[i]);
  }
  return *this;
}
// Add an edge other -> this to the owning pattern for every entry of
// `others`. Returns *this so that calls can be chained.
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
  for (size_t i = 0; i < others.size(); ++i) {
    pattern_->AddEdge(others[i], this);
  }
  return *this;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detect
e
r.h
→
paddle/fluid/framework/ir/graph_pattern_detect
o
r.h
浏览文件 @
902f19b4
...
@@ -21,12 +21,14 @@
...
@@ -21,12 +21,14 @@
#include <numeric>
#include <numeric>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
PDPattern
;
// Some basic t
orminoly
gies:
// Some basic t
erminolo
gies:
// - PDPattern: a pattern defined as a data flow graph.
// - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`.
// that meets some conditions defined in `PDNode.teller`.
...
@@ -36,30 +38,43 @@ namespace ir {
...
@@ -36,30 +38,43 @@ namespace ir {
struct
PDNode
{
struct
PDNode
{
// tell whether an ir::Node* is a candidate for a PDNode.
// tell whether an ir::Node* is a candidate for a PDNode.
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
enum
class
Type
{
kOp
,
kVar
};
PDNode
(
teller_t
&&
teller
,
const
std
::
string
&
name
=
""
)
// this link to others
:
teller_
(
teller
),
name_
(
name
)
{
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"invalid teller functer is set."
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
}
PDNode
(
PDNode
&&
other
)
=
default
;
std
::
vector
<
PDNode
*>
inlinks
;
std
::
vector
<
PDNode
*>
outlinks
;
bool
Tell
(
Node
*
node
)
const
{
bool
Tell
(
Node
*
node
)
const
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
return
teller_
(
node
);
return
teller_
(
node
);
}
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
bool
IsVar
()
const
{
return
type_
==
Type
::
kVar
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
private:
private:
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
teller_
(
std
::
move
(
teller
)),
pattern_
(
pattern
),
name_
(
name
),
type_
(
type
)
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"invalid teller functer is set."
);
}
PDNode
(
PDNode
&&
other
)
=
default
;
friend
class
PDPattern
;
teller_t
teller_
;
teller_t
teller_
;
PDPattern
*
pattern_
;
std
::
string
name_
;
std
::
string
name_
;
Type
type_
;
};
};
/*
/*
...
@@ -102,6 +117,8 @@ class PDPattern {
...
@@ -102,6 +117,8 @@ class PDPattern {
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
std
::
string
DotString
()
const
;
private:
private:
#ifdef PADDLE_WITH_TESTING
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
PDPattern
,
AddEdge
);
FRIEND_TEST
(
PDPattern
,
AddEdge
);
...
@@ -117,7 +134,7 @@ class PDPattern {
...
@@ -117,7 +134,7 @@ class PDPattern {
};
};
/*
/*
* GraphPatternDetect
e
r helps to detect the specific patterns in the graph.
* GraphPatternDetect
o
r helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes.
* Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
*
*
...
@@ -129,7 +146,7 @@ class PDPattern {
...
@@ -129,7 +146,7 @@ class PDPattern {
*
*
* Usage:
* Usage:
* // Create a detector
* // Create a detector
* GraphPatternDetect
e
r detector;
* GraphPatternDetect
o
r detector;
* // Define the detector's pattern, by adding PDNode and define the edges.
* // Define the detector's pattern, by adding PDNode and define the edges.
* auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...)
...
@@ -138,11 +155,11 @@ class PDPattern {
...
@@ -138,11 +155,11 @@ class PDPattern {
* detector.mutable_pattern().AddEdge(node0, node1);
* detector.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered
* // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns.
* // subgraphs that comply with the patterns.
* GraphPatternDetect
e
r::handle_t handler = some labmda
* GraphPatternDetect
o
r::handle_t handler = some lambda
* // Execute the detector.
* // Execute the detector.
* detector(&graph, handler);
* detector(&graph, handler);
*/
*/
class
GraphPatternDetect
e
r
{
class
GraphPatternDetect
o
r
{
public:
public:
using
subgraph_t
=
std
::
unordered_map
<
PDNode
*
,
Node
*>
;
using
subgraph_t
=
std
::
unordered_map
<
PDNode
*
,
Node
*>
;
...
@@ -177,10 +194,62 @@ class GraphPatternDetecter {
...
@@ -177,10 +194,62 @@ class GraphPatternDetecter {
using
hit_rcd_t
=
using
hit_rcd_t
=
std
::
pair
<
Node
*
/*node in graph*/
,
PDNode
*
/*node in pattern*/
>
;
std
::
pair
<
Node
*
/*node in graph*/
,
PDNode
*
/*node in pattern*/
>
;
PDPattern
pattern_
;
PDPattern
pattern_
;
std
::
vector
<
hit_rcd_t
>
marked_records_
;
std
::
unordered_map
<
const
PDNode
*
,
std
::
unordered_set
<
Node
*>>
pdnodes2nodes_
;
std
::
unordered_map
<
const
PDNode
*
,
std
::
unordered_set
<
Node
*>>
pdnodes2nodes_
;
};
};
// some helper methods.
// Tell whether `node` has, among its outputs, at least one operator node
// whose op type equals `op_type` (i.e. `node` is an input of such an op).
static bool VarLinksToOp(Node* node, const std::string& op_type) {
  for (auto* consumer : node->outputs) {
    if (!consumer->IsOp()) continue;
    if (consumer->Op()->Type() == op_type) return true;
  }
  return false;
}
// Tell whether `node` has, among its inputs, at least one operator node
// whose op type equals `op_type` (i.e. `node` is an output of such an op).
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
  for (auto* producer : node->inputs) {
    if (!producer->IsOp()) continue;
    if (producer->Op()->Type() == op_type) return true;
  }
  return false;
}
// Check whether a var node is the nth entry of an op node's input slot
// `argument`.
//
//   var:      must be a variable node (enforced).
//   op:       must be an operator node (enforced).
//   argument: name of the op's input slot, e.g. "X".
//   nth:      zero-based position inside that slot.
//
// Returns false when the slot holds fewer than nth + 1 names; otherwise
// returns whether the slot's nth name equals var's name.
static bool IsNthInput(Node* var, Node* op, const std::string& argument,
                       size_t nth) {
  PADDLE_ENFORCE(var->IsVar());
  PADDLE_ENFORCE(op->IsOp());
  // Bound the index by the container that is actually subscripted below.
  // op->inputs lists the op's input links across all slots, so its size can
  // exceed the size of a single slot and does not protect the subscript.
  const auto& slot = op->Op()->Input(argument);
  if (slot.size() <= nth) return false;
  return var->Name() == slot[nth];
}
// Remove every node in `nodes` from `graph`, then drop all links to the
// removed nodes from the input and output lists of the nodes still held by
// the graph.
static void GraphSafeRemoveNodes(
    Graph* graph, const std::unordered_set<const Node*>& nodes) {
  for (auto* doomed : nodes) {
    graph->RemoveNode(const_cast<Node*>(doomed));
  }

  for (auto* survivor : graph->Nodes()) {
    // Purge links that point at a removed node from the input list.
    auto in_it = survivor->inputs.begin();
    while (in_it != survivor->inputs.end()) {
      if (nodes.count(*in_it)) {
        in_it = const_cast<Node*>(survivor)->inputs.erase(in_it);
        continue;
      }
      in_it++;
    }

    // Same for the output list.
    auto out_it = survivor->outputs.begin();
    while (out_it != survivor->outputs.end()) {
      if (nodes.count(*out_it)) {
        out_it = const_cast<Node*>(survivor)->outputs.erase(out_it);
        continue;
      }
      out_it++;
    }
  }
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detect
e
r_tester.cc
→
paddle/fluid/framework/ir/graph_pattern_detect
o
r_tester.cc
浏览文件 @
902f19b4
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detect
e
r.h"
#include "paddle/fluid/framework/ir/graph_pattern_detect
o
r.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
...
@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
...
@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
}
}
TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
)
{
TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
)
{
GraphPatternDetect
e
r
x
;
GraphPatternDetect
o
r
x
;
// mark o2, o3, v2
// mark o2, o3, v2
// The pattern is a graph:
// The pattern is a graph:
...
@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
...
@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
Graph
graph
(
program
);
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
BuildGraph
(
&
graph
);
GraphPatternDetect
e
r
x
;
GraphPatternDetect
o
r
x
;
// The pattern is a graph:
// The pattern is a graph:
// op -> var
// op -> var
...
@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
...
@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
x
.
mutable_pattern
()
->
AddEdge
(
any_var
,
any_op1
);
x
.
mutable_pattern
()
->
AddEdge
(
any_var
,
any_op1
);
int
count
=
0
;
int
count
=
0
;
GraphPatternDetect
e
r
::
handle_t
handle
=
[
&
](
GraphPatternDetect
o
r
::
handle_t
handle
=
[
&
](
const
GraphPatternDetect
e
r
::
subgraph_t
&
s
,
Graph
*
g
)
{
const
GraphPatternDetect
o
r
::
subgraph_t
&
s
,
Graph
*
g
)
{
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
Name
()
<<
" -> "
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
Name
()
<<
" -> "
<<
s
.
at
(
any_var
)
->
Name
()
<<
" -> "
<<
s
.
at
(
any_op1
)
->
Name
();
<<
s
.
at
(
any_var
)
->
Name
()
<<
" -> "
<<
s
.
at
(
any_op1
)
->
Name
();
count
++
;
count
++
;
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
902f19b4
...
@@ -16,11 +16,13 @@ limitations under the License. */
...
@@ -16,11 +16,13 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
using
inference
::
analysis
::
Dot
;
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
...
@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
...
@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
std
::
ostream
&
sout
=
*
fout
;
std
::
ostream
&
sout
=
*
fout
;
size_t
var_id
=
0
;
std
::
unordered_map
<
const
ir
::
Node
*
,
std
::
string
>
node2dot
;
std
::
unordered_map
<
const
ir
::
Node
*
,
size_t
>
vars
;
Dot
dot
;
sout
<<
"digraph G {
\n
"
;
std
::
vector
<
Dot
::
Attr
>
op_attrs
({
Dot
::
Attr
(
"style"
,
"filled"
),
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
Dot
::
Attr
(
"shape"
,
"box"
),
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kVariable
)
continue
;
Dot
::
Attr
(
"fillcolor"
,
"red"
)});
size_t
cur_var_id
=
var_id
++
;
std
::
vector
<
Dot
::
Attr
>
var_attrs
({
Dot
::
Attr
(
"style"
,
"filled,rounded"
),
vars
[
n
]
=
cur_var_id
;
// Dot::Attr("shape", "diamond"),
Dot
::
Attr
(
"fillcolor"
,
"yellow"
)});
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
]"
<<
std
::
endl
;
std
::
vector
<
Dot
::
Attr
>
marked_op_attrs
({
Dot
::
Attr
(
"style"
,
"filled"
),
Dot
::
Attr
(
"shape"
,
"box"
),
Dot
::
Attr
(
"fillcolor"
,
"lightgray"
)});
std
::
vector
<
Dot
::
Attr
>
marked_var_attrs
(
{
Dot
::
Attr
(
"style"
,
"filled,rounded"
),
// Dot::Attr("shape", "diamond"),
Dot
::
Attr
(
"fillcolor"
,
"lightgray"
)});
auto
marked_nodes
=
ConsumeMarkedNodes
(
graph
.
get
());
// Create nodes
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
std
::
string
node_id
=
n
->
Name
()
+
"("
+
std
::
to_string
(
n
->
id
())
+
")"
;
if
(
n
->
IsOp
())
{
decltype
(
op_attrs
)
attr
=
marked_nodes
.
count
(
n
)
?
marked_op_attrs
:
op_attrs
;
dot
.
AddNode
(
node_id
,
attr
,
node_id
);
}
else
if
(
n
->
IsVar
())
{
decltype
(
op_attrs
)
attr
=
marked_nodes
.
count
(
n
)
?
marked_var_attrs
:
var_attrs
;
dot
.
AddNode
(
node_id
,
attr
,
node_id
);
}
}
node2dot
[
n
]
=
node_id
;
size_t
op_id
=
0
;
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
n
->
inputs
)
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
}
// Create edges
for
(
auto
out
:
n
->
outputs
)
{
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
const
auto
&
src_id
=
node2dot
.
at
(
n
);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
for
(
auto
*
out
:
n
->
outputs
)
{
const
auto
&
trg_id
=
node2dot
.
at
(
out
);
dot
.
AddEdge
(
src_id
,
trg_id
,
{});
}
}
}
}
sout
<<
"}
\n
"
;
sout
<<
dot
.
Build
();
return
graph
;
return
graph
;
}
}
// Return a copy of the node set stored on `graph` under
// kGraphvizMarkedNodeAttr and clear the stored set afterwards. When the
// graph has no such attribute an empty set is returned.
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
    Graph* graph) const {
  marked_nodes_t snapshot;
  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    return snapshot;
  }
  auto& stored = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
  snapshot = stored;
  stored.clear();
  return snapshot;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/graph_viz_pass.h
浏览文件 @
902f19b4
...
@@ -27,10 +27,19 @@ namespace paddle {
...
@@ -27,10 +27,19 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
const
char
kGraphvizMarkedNodeAttr
[]
=
"__graphviz__marked_node__"
;
class
GraphVizPass
:
public
Pass
{
class
GraphVizPass
:
public
Pass
{
public:
using
marked_nodes_t
=
std
::
unordered_set
<
const
Node
*>
;
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
marked_nodes_t
ConsumeMarkedNodes
(
Graph
*
graph
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
902f19b4
...
@@ -29,20 +29,26 @@ class Node {
...
@@ -29,20 +29,26 @@ class Node {
enum
class
Type
{
kOperation
,
kVariable
};
enum
class
Type
{
kOperation
,
kVariable
};
static
constexpr
char
kControlDepVarName
[]
=
"__control_var"
;
static
constexpr
char
kControlDepVarName
[]
=
"__control_var"
;
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
explicit
Node
(
const
std
::
string
&
name
,
Type
type
,
int
id
=
-
1
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
),
id_
(
id
)
{}
explicit
Node
(
VarDesc
*
var_desc
)
explicit
Node
(
VarDesc
*
var_desc
,
int
id
=
-
1
)
:
name_
(
var_desc
->
Name
()),
:
name_
(
var_desc
->
Name
()),
var_desc_
(
new
VarDesc
(
*
var_desc
)),
var_desc_
(
new
VarDesc
(
*
var_desc
)),
op_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
Type
::
kVariable
)
{}
type_
(
Type
::
kVariable
),
id_
(
id
)
{}
explicit
Node
(
OpDesc
*
op_desc
)
explicit
Node
(
OpDesc
*
op_desc
,
int
id
=
-
1
)
:
name_
(
op_desc
->
Type
()),
:
name_
(
op_desc
->
Type
()),
var_desc_
(
nullptr
),
var_desc_
(
nullptr
),
op_desc_
(
new
OpDesc
(
*
op_desc
,
op_desc
->
Block
())),
op_desc_
(
new
OpDesc
(
*
op_desc
,
op_desc
->
Block
())),
type_
(
Type
::
kOperation
)
{}
type_
(
Type
::
kOperation
),
id_
(
id
)
{}
Type
NodeType
()
const
{
return
type_
;
}
Type
NodeType
()
const
{
return
type_
;
}
...
@@ -58,6 +64,8 @@ class Node {
...
@@ -58,6 +64,8 @@ class Node {
return
op_desc_
.
get
();
return
op_desc_
.
get
();
}
}
int
id
()
const
{
return
id_
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOperation
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOperation
;
}
bool
IsVar
()
const
{
return
type_
==
Type
::
kVariable
;
}
bool
IsVar
()
const
{
return
type_
==
Type
::
kVariable
;
}
...
@@ -69,6 +77,7 @@ class Node {
...
@@ -69,6 +77,7 @@ class Node {
std
::
unique_ptr
<
VarDesc
>
var_desc_
;
std
::
unique_ptr
<
VarDesc
>
var_desc_
;
std
::
unique_ptr
<
OpDesc
>
op_desc_
;
std
::
unique_ptr
<
OpDesc
>
op_desc_
;
Type
type_
;
Type
type_
;
int
id_
;
private:
private:
DISABLE_COPY_AND_ASSIGN
(
Node
);
DISABLE_COPY_AND_ASSIGN
(
Node
);
...
...
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Empty placeholder type; nothing else in this file refers to it.
struct FuseExpr {};
// Add the "two sequence_expand ops feeding a three-input concat" pattern to
// `pattern` and return the PDNode that stands for concat's output variable.
PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
  // The following operators will be fused:
  //   concat
  //   sequence_expand
  //   sequence_expand
  // The following variables will be treated as inputs:
  //   concat's remaining input (concat_in0), 0th input of the fused op
  //   sequence_expand input, 1st input of the fused op
  //   sequence_expand input, 2nd input of the fused op
  // The following variables will be treated as outputs:
  //   concat output
  // So the following variables will be removed:
  //   sequence_expand output
  //   sequence_expand output

  // Three operators
  auto* sequence_expand0 = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
      },
      "sequence_expand0");

  auto* sequence_expand1 = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
      },
      "sequence_expand1");

  auto* concat = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "concat" &&  // basic check
               x->Op()->Input("X").size() == 3;  // only 3-input concat matches
      },
      "concat");

  // Inputs of the two sequence_expand ops: any variable consumed by a
  // sequence_expand op.
  auto* sequence_expand0_in = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
      },
      "sequence_expand0_in");

  auto* sequence_expand1_in = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
      },
      "sequence_expand1_in");

  // The variables
  // NOTE(review): the IsNthInput checks below inspect x->outputs[0] only,
  // while VarLinksToOp accepts a concat anywhere in x->outputs; this looks
  // like it assumes concat is the variable's first consumer -- confirm.
  auto* sequence_expand0_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&
               VarLinksFromOp(x, "sequence_expand") &&  // basic check
               VarLinksToOp(x, "concat") &&             // is concat's input
               IsNthInput(x, x->outputs[0], "X", 1);    // X[1]
      },
      "sequence_expand0_out");

  auto* sequence_expand1_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&
               VarLinksFromOp(x, "sequence_expand") &&  // basic check
               VarLinksToOp(x, "concat") &&             // is concat's input
               IsNthInput(x, x->outputs[0], "X", 2);    // X[2]
      },
      "sequence_expand1_out");

  // concat's third input: any variable consumed by a concat op.
  auto* concat_in0 = pattern->NewNode(
      [](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
      "concat_in0");

  // concat's output: any variable produced by a concat op.
  auto* concat_out = pattern->NewNode(
      [](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
      "concat_out");

  // Links
  sequence_expand0->LinksFrom({sequence_expand0_in})
      .LinksTo({sequence_expand0_out});
  sequence_expand1->LinksFrom({sequence_expand1_in})
      .LinksTo({sequence_expand1_out});
  concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
      .LinksTo({concat_out});
  return concat_out;
}
// Add a "mul -> elementwise_add -> activation" pattern to `pattern`.
// `fc_x` is an already created PDNode that is linked in as an input of the
// mul op. Returns the PDNode that stands for the activation's output
// variable.
PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
  // Weight: a persistable variable consumed by a mul op.
  PDNode* fc_w = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                 // basic
               VarLinksToOp(x, "mul") &&          // link
               x->Var()->Proto()->persistable();  // is a parameter
      },
      "fc_w");

  // Output of mul that feeds elementwise_add.
  PDNode* mul_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                     // basic
               VarLinksFromOp(x, "mul") &&            // produced by mul
               VarLinksToOp(x, "elementwise_add") &&  // consumed by the add
               !x->Var()->Proto()->persistable();     // is NOT a parameter
      },
      "mul_out");

  PDNode* fc_mul = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "mul";  // basic
      },
      "fc_mul");

  // Bias: a persistable variable consumed by an elementwise_add op.
  PDNode* fc_bias = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                     // basic
               VarLinksToOp(x, "elementwise_add") &&  // link
               x->Var()->Proto()->persistable();      // is a parameter
      },
      "fc_bias");

  PDNode* elementwise_add = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
      },
      "elementwise_add");

  // Output of elementwise_add.
  PDNode* add_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                       // basic
               VarLinksFromOp(x, "elementwise_add") &&  // link
               !x->Var()->Proto()->persistable();       // is NOT a parameter
      },
      "add_out");

  // Op types accepted as the activation. The lambda below captures this set
  // by copy ([=]), so the teller does not refer to the local after return.
  std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
  PDNode* act = pattern->NewNode(
      [=](Node* x) { return x && x->IsOp() && acts.count(x->Op()->Type()); },
      "act");

  // Final output: the teller itself only requires a non-persistable
  // variable; its relation to the activation comes from the link below.
  PDNode* fc_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                  // basic
               !x->Var()->Proto()->persistable();  // is NOT a parameter
      },
      "fc_out");

  fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
  elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
  act->LinksFrom({add_out}).LinksTo({fc_out});
  return fc_out;
}
// Detect every "sequence_expand x2 + concat + fc" subgraph in `graph` and
// replace it with a single fusion_seqexpand_concat_fc op node.
// Takes ownership of `graph` and returns it after the rewrite.
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(graph.get());

  // Build the pattern: the concat part first, then the fc part hanging off
  // concat's output.
  GraphPatternDetector detector;
  auto* pattern = detector.mutable_pattern();
  auto* concat_out = BuildSeqExpandConcatPattern(pattern);
  BuildFCPattern(pattern, concat_out);

// GET_NODE(id, pattern) declares a local `id` bound to the graph node that
// the matched subgraph maps the PDNode named #id to, enforcing that the
// mapping exists and is non-null. It expects a variable called `subgraph`
// to be in scope. The macro is not #undef'd in this function.
#define GET_NODE(id, pattern)                              \
  PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
                 "pattern has no Node called %s", #id);    \
  auto* id = subgraph.at(pattern.RetriveNode(#id));        \
  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

  // Handler run for each match. Note that its `graph` parameter (a raw
  // Graph*) shadows the unique_ptr `graph` of the enclosing function.
  detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
                            Graph* graph) {
    VLOG(4) << "get one concat pattern";
    // fc
    GET_NODE(fc_w, detector.pattern());
    GET_NODE(fc_bias, detector.pattern());
    GET_NODE(act, detector.pattern());
    GET_NODE(fc_out, detector.pattern());

    // concat
    GET_NODE(concat_in0, detector.pattern());
    GET_NODE(sequence_expand0_in, detector.pattern());
    GET_NODE(sequence_expand1_in, detector.pattern());

    // Describe the fused op: X takes concat's remaining input followed by
    // the inputs of the two sequence_expand ops.
    OpDesc op_desc;
    op_desc.SetType("fusion_seqexpand_concat_fc");
    op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
                           sequence_expand1_in->Name()});
    op_desc.SetInput("FCWeight", {fc_w->Name()});
    op_desc.SetInput("FCBias", {fc_bias->Name()});
    // The FCOut output is bound to "<fc_out name>.tmp"; a LoDTensor variable
    // of that name is requested from param_scope() before use.
    const std::string fc_out_tmp = fc_out->Name() + ".tmp";
    param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
    op_desc.SetOutput("FCOut", {fc_out_tmp});
    op_desc.SetOutput("Out", {fc_out->Name()});
    op_desc.SetAttr("fc_activation", act->Op()->Type());

    auto* op_node = graph->CreateOpNode(&op_desc);
// Add links
// NODE_LINKS(a, b) records the edge a -> b on both endpoints. It expands to
// two statements; it is not #undef'd in this function.
#define NODE_LINKS(a, b)   \
  a->outputs.push_back(b); \
  b->inputs.push_back(a);
    NODE_LINKS(fc_w, op_node);
    NODE_LINKS(fc_bias, op_node);
    NODE_LINKS(concat_in0, op_node);
    NODE_LINKS(sequence_expand0_in, op_node);
    NODE_LINKS(sequence_expand1_in, op_node);
    NODE_LINKS(op_node, fc_out);

    // Clean nodes.
    // Start from every matched node, then keep the ones that were just
    // linked to the fused op; everything left in the set is removed.
    std::unordered_set<const Node*> marked_nodes;
    for (auto& item : subgraph) {
      marked_nodes.insert(item.second);
    }
    marked_nodes.erase(fc_w);
    marked_nodes.erase(fc_bias);
    marked_nodes.erase(concat_in0);
    marked_nodes.erase(sequence_expand0_in);
    marked_nodes.erase(sequence_expand1_in);
    marked_nodes.erase(fc_out);
    GraphSafeRemoveNodes(graph, marked_nodes);
  });

  return graph;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Registers SeqConcatFcFusePass under the pass name "seq_concat_fc_fuse_pass".
// This is the name that USE_PASS(...) and the analyzer's pass list refer to.
REGISTER_PASS
(
seq_concat_fc_fuse_pass
,
paddle
::
framework
::
ir
::
SeqConcatFcFusePass
);
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Graph pass that replaces each matched sequence_expand/concat + FC +
// activation subgraph with a single "fusion_seqexpand_concat_fc" operator.
class SeqConcatFcFusePass : public FusePassBase {
 public:
  virtual ~SeqConcatFcFusePass() {}

 protected:
  // Takes ownership of `graph`, applies the fusion and hands the graph back.
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
902f19b4
cc_library
(
ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass
)
cc_library
(
ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass
)
cc_library
(
analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
set
(
analysis_deps
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor
)
cc_library
(
analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc
analyzer.cc
helper.cc
helper.cc
# passes
# passes
...
@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
...
@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc
tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc
fluid_to_ir_pass.cc
model_store_pass.cc
model_store_pass.cc
DEPS
framework_proto proto_desc ir_pass_manager graph pass
)
DEPS
${
analysis_deps
}
)
cc_test
(
test_node SRCS node_tester.cc DEPS analysis
)
cc_test
(
test_node SRCS node_tester.cc DEPS analysis
)
cc_test
(
test_dot SRCS dot_tester.cc DEPS analysis
)
cc_test
(
test_dot SRCS dot_tester.cc DEPS analysis
)
cc_binary
(
inference_analyzer SRCS analyzer_main.cc DEPS analysis
)
cc_binary
(
inference_analyzer SRCS analyzer_main.cc DEPS analysis
paddle_fluid
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
...
@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
...
@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
endif
()
endif
()
cc_test
(
${
TARGET
}
cc_test
(
${
TARGET
}
SRCS
"
${
analysis_test_SRCS
}
"
SRCS
"
${
analysis_test_SRCS
}
"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detect
e
r pass
${
analysis_test_EXTRA_DEPS
}
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detect
o
r pass
${
analysis_test_EXTRA_DEPS
}
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
${
mem_opt
}
)
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
${
mem_opt
}
)
set_tests_properties
(
${
TARGET
}
PROPERTIES DEPENDS test_word2vec
)
set_tests_properties
(
${
TARGET
}
PROPERTIES DEPENDS test_word2vec
)
endif
(
WITH_TESTING
)
endif
(
WITH_TESTING
)
...
@@ -58,20 +61,25 @@ endif()
...
@@ -58,20 +61,25 @@ endif()
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
analysis_predictor
# ir
# ir
fc_fuse_pass
fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass
graph_viz_pass
infer_clean_graph_pass
infer_clean_graph_pass
graph_pattern_detect
e
r
graph_pattern_detect
o
r
infer_clean_graph_pass
infer_clean_graph_pass
attention_lstm_fuse_pass
paddle_inference_api
pass
pass
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
--infer_ditu_rnn_model=
${
DITU_INSTALL_DIR
}
/model
--infer_ditu_rnn_model=
${
DITU_INSTALL_DIR
}
/model
--infer_ditu_rnn_data=
${
DITU_INSTALL_DIR
}
/data.txt
)
--infer_ditu_rnn_data=
${
DITU_INSTALL_DIR
}
/data.txt
)
inference_analysis_test
(
test_data_flow_graph SRCS data_flow_graph_tester.cc
)
inference_analysis_test
(
test_data_flow_graph SRCS data_flow_graph_tester.cc
)
inference_analysis_test
(
test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc
)
inference_analysis_test
(
test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc
EXTRA_DEPS paddle_inference_api
)
inference_analysis_test
(
test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc
)
inference_analysis_test
(
test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc
EXTRA_DEPS paddle_fluid
)
inference_analysis_test
(
test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc
)
inference_analysis_test
(
test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc
)
inference_analysis_test
(
test_subgraph_splitter SRCS subgraph_splitter_tester.cc
)
inference_analysis_test
(
test_subgraph_splitter SRCS subgraph_splitter_tester.cc
)
inference_analysis_test
(
test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc
)
inference_analysis_test
(
test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc
)
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
902f19b4
...
@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer
::
Analyzer
()
{
Register
(
"manager1"
,
new
DfgPassManagerImpl
);
}
Analyzer
::
Analyzer
()
{
Register
(
"manager1"
,
new
DfgPassManagerImpl
);
}
void
Analyzer
::
Run
(
Argument
*
argument
)
{
void
Analyzer
::
Run
(
Argument
*
argument
)
{
// Ugly support for fluid-to-ir-pass
argument
->
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
({
// Manual update the passes here.
"graph_viz_pass"
,
//
"infer_clean_graph_pass"
,
"graph_viz_pass"
,
//
"attention_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"seq_concat_fc_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_fuse_pass"
,
"graph_viz_pass"
//
}));
for
(
auto
&
x
:
data_
)
{
for
(
auto
&
x
:
data_
)
{
PADDLE_ENFORCE
(
x
->
Initialize
(
argument
));
PADDLE_ENFORCE
(
x
->
Initialize
(
argument
));
x
->
RunAll
();
x
->
RunAll
();
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
902f19b4
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_data
,
""
,
"data path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_data
,
""
,
"data path for ditu RNN"
);
...
@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
const
std
::
string
&
data_path
,
int
batch_size
,
const
std
::
string
&
data_path
,
int
batch_size
,
bool
use_analysis
,
bool
activate_ir
,
bool
use_analysis
,
bool
activate_ir
,
int
num_times
=
1
)
{
int
num_times
=
1
)
{
FLAGS_IA_enable_ir
=
activate_ir
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
"./analysis.out"
;
std
::
string
model_out
;
if
(
use_analysis
)
{
Argument
argument
(
model_path
);
argument
.
model_output_store_path
.
reset
(
new
std
::
string
(
"./analysis.out"
));
Analyzer
analyzer
;
analyzer
.
Run
(
&
argument
);
// Should get the transformed model stored to ./analysis.out
model_out
=
"./analysis.out"
;
ASSERT_TRUE
(
PathExists
(
model_out
));
}
else
{
model_out
=
FLAGS_infer_ditu_rnn_model
;
}
NativeConfig
config
;
NativeConfig
config
;
config
.
prog_file
=
model_out
+
"/__model__"
;
config
.
prog_file
=
FLAGS_infer_ditu_rnn_model
+
"/__model__"
;
config
.
param_file
=
model_out
+
"/param"
;
config
.
param_file
=
FLAGS_infer_ditu_rnn_model
+
"/param"
;
config
.
use_gpu
=
false
;
config
.
use_gpu
=
false
;
config
.
device
=
0
;
config
.
device
=
0
;
config
.
specify_input_name
=
true
;
config
.
specify_input_name
=
true
;
auto
predictor
=
auto
base_
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kAnalysis
>
(
config
);
std
::
vector
<
PaddleTensor
>
input_slots
;
std
::
vector
<
PaddleTensor
>
input_slots
;
DataRecord
data
(
data_path
,
batch_size
);
DataRecord
data
(
data_path
,
batch_size
);
// Prepare inputs.
// Prepare inputs.
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
PaddleTensor
>
outputs
,
base_outputs
;
base_predictor
->
Run
(
input_slots
,
&
base_outputs
);
Timer
timer
;
Timer
timer
;
timer
.
tic
();
timer
.
tic
();
...
@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
<<
", latency: "
<<
timer
.
toc
()
/
num_times
<<
"ms"
;
<<
", latency: "
<<
timer
.
toc
()
/
num_times
<<
"ms"
;
LOG
(
INFO
)
<<
"====================================="
;
LOG
(
INFO
)
<<
"====================================="
;
for
(
auto
&
out
:
outputs
)
{
PADDLE_ENFORCE_GT
(
outputs
.
size
(),
0
);
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
base_outputs
.
size
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
i
++
)
{
auto
&
out
=
outputs
[
i
];
auto
&
base_out
=
base_outputs
[
i
];
size_t
size
=
std
::
accumulate
(
out
.
shape
.
begin
(),
out
.
shape
.
end
(),
1
,
size_t
size
=
std
::
accumulate
(
out
.
shape
.
begin
(),
out
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
size_t
size1
=
std
::
accumulate
(
base_out
.
shape
.
begin
(),
base_out
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
PADDLE_ENFORCE_EQ
(
size
,
size1
);
PADDLE_ENFORCE_GT
(
size
,
0
);
float
*
data
=
static_cast
<
float
*>
(
out
.
data
.
data
());
float
*
data
=
static_cast
<
float
*>
(
out
.
data
.
data
());
for
(
size_t
i
=
0
;
float
*
base_data
=
static_cast
<
float
*>
(
base_out
.
data
.
data
());
i
<
std
::
min
(
sizeof
(
ditu_rnn_target_data
)
/
sizeof
(
float
),
size
);
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
i
++
)
{
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
EXPECT_NEAR
(
data
[
i
],
ditu_rnn_target_data
[
i
],
1e-3
);
}
}
}
}
}
}
// Turn on the IR pass supportion, run a real inference and check the result.
TEST
(
Analyzer
,
SupportIRPass
)
{
FLAGS_IA_enable_ir
=
true
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
"./analysis.out"
;
Argument
argument
(
FLAGS_inference_model_dir
);
argument
.
model_output_store_path
.
reset
(
new
std
::
string
(
"./analysis.out"
));
Analyzer
analyzer
;
analyzer
.
Run
(
&
argument
);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE
(
PathExists
(
"./analysis.out"
));
// Inference from this path.
TestWord2vecPrediction
(
"./analysis.out"
);
}
// Directly infer with the original model.
// Directly infer with the original model.
TEST
(
Analyzer
,
DituRNN_without_analysis
)
{
TEST
(
Analyzer
,
DituRNN_without_analysis
)
{
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
...
@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
...
@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
seq_concat_fc_fuse_pass
);
USE_PASS
(
fc_lstm_fuse_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
attention_lstm_fuse_pass
);
paddle/fluid/inference/analysis/argument.h
浏览文件 @
902f19b4
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <string>
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -58,6 +59,46 @@ struct Argument {
...
@@ -58,6 +59,46 @@ struct Argument {
// The output storage path of ModelStorePass.
// The output storage path of ModelStorePass.
std
::
unique_ptr
<
std
::
string
>
model_output_store_path
;
std
::
unique_ptr
<
std
::
string
>
model_output_store_path
;
// Support for any other attributes.
// Stores `data` as the attribute `key`.
//
// The argument takes ownership of `data`: a deleter is recorded for it and is
// run by ~Argument(), unless the attribute is handed back via Release() first.
// Fails if `data` is null or if an attribute named `key` already exists.
template <typename T>
void Set(const std::string& key, T* data) {
  PADDLE_ENFORCE_NOT_NULL(data);
  PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
  attrs_[key] = data;
  // The deleter remembers the static type T, which boost::any has erased.
  attr_deleters_[key] = [data, key]() {
    VLOG(3) << "argument delete attr: " << key;
    delete data;
  };
}
// Tells whether an attribute called `name` is currently stored.
bool Has(const std::string& name) const {
  return attrs_.find(name) != attrs_.end();
}
// Gives the attribute `key` back to the caller and forgets about it.
//
// The stored deleter is dropped without being invoked, so the returned
// pointer is not deleted by this object. Fails if `key` is not present.
template <typename T>
T* Release(const std::string& key) {
  PADDLE_ENFORCE(attrs_.count(key));
  T* released = boost::any_cast<T*>(attrs_.at(key));
  attr_deleters_.erase(key);
  attrs_.erase(key);
  return released;
}
// Returns a reference to the object stored under `key`; the attribute stays
// in this argument. Fails if `key` is not present.
template <typename T>
T& Get(const std::string& key) {
  PADDLE_ENFORCE(Has(key));
  T* stored = boost::any_cast<T*>(attrs_.at(key));
  return *stored;
}
// Runs every deleter that is still registered, i.e. frees all attributes that
// were Set() and not Release()d.
~Argument() {
  for (auto it = attr_deleters_.begin(); it != attr_deleters_.end(); ++it) {
    it->second();
  }
}
private:
std
::
unordered_map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
void
()
>>
attr_deleters_
;
};
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
902f19b4
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
...
@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
}
}
if
(
argument_
->
Has
(
"param_scope"
))
{
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
}
PADDLE_ENFORCE
(
argument_
->
transformed_program_desc
.
get
());
PADDLE_ENFORCE
(
argument_
->
transformed_program_desc
.
get
());
}
}
...
...
paddle/fluid/inference/analysis/dot.h
浏览文件 @
902f19b4
...
@@ -29,13 +29,13 @@ namespace paddle {
...
@@ -29,13 +29,13 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
static
size_t
dot_node_counter
{
0
};
/*
/*
* A Dot template that helps to build a DOT graph definition.
* A Dot template that helps to build a DOT graph definition.
*/
*/
class
Dot
{
class
Dot
{
public:
public:
static
size_t
counter
;
struct
Attr
{
struct
Attr
{
std
::
string
key
;
std
::
string
key
;
std
::
string
value
;
std
::
string
value
;
...
@@ -57,7 +57,7 @@ class Dot {
...
@@ -57,7 +57,7 @@ class Dot {
Node
(
const
std
::
string
&
name
,
const
std
::
vector
<
Attr
>&
attrs
)
Node
(
const
std
::
string
&
name
,
const
std
::
vector
<
Attr
>&
attrs
)
:
name
(
name
),
:
name
(
name
),
attrs
(
attrs
),
attrs
(
attrs
),
id_
(
"node_"
+
std
::
to_string
(
Dot
::
counter
++
))
{}
id_
(
"node_"
+
std
::
to_string
(
dot_node_
counter
++
))
{}
std
::
string
id
()
const
{
return
id_
;
}
std
::
string
id
()
const
{
return
id_
;
}
...
@@ -65,6 +65,10 @@ class Dot {
...
@@ -65,6 +65,10 @@ class Dot {
std
::
stringstream
ss
;
std
::
stringstream
ss
;
CHECK
(
!
name
.
empty
());
CHECK
(
!
name
.
empty
());
ss
<<
id_
;
ss
<<
id_
;
if
(
attrs
.
empty
())
{
ss
<<
"[label="
<<
'"'
<<
name
<<
'"'
<<
"]"
;
return
ss
.
str
();
}
for
(
size_t
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
if
(
i
==
0
)
{
if
(
i
==
0
)
{
ss
<<
"[label="
<<
'"'
<<
name
<<
'"'
<<
" "
;
ss
<<
"[label="
<<
'"'
<<
name
<<
'"'
<<
" "
;
...
@@ -108,9 +112,11 @@ class Dot {
...
@@ -108,9 +112,11 @@ class Dot {
explicit
Dot
(
const
std
::
vector
<
Attr
>&
attrs
)
:
attrs_
(
attrs
)
{}
explicit
Dot
(
const
std
::
vector
<
Attr
>&
attrs
)
:
attrs_
(
attrs
)
{}
void
AddNode
(
const
std
::
string
&
name
,
const
std
::
vector
<
Attr
>&
attrs
)
{
void
AddNode
(
const
std
::
string
&
id
,
const
std
::
vector
<
Attr
>&
attrs
,
CHECK
(
!
nodes_
.
count
(
name
))
<<
"duplicate Node '"
<<
name
<<
"'"
;
std
::
string
label
=
""
)
{
nodes_
.
emplace
(
name
,
Node
{
name
,
attrs
});
CHECK
(
!
nodes_
.
count
(
id
))
<<
"duplicate Node '"
<<
id
<<
"'"
;
if
(
label
.
empty
())
label
=
id
;
nodes_
.
emplace
(
id
,
Node
{
label
,
attrs
});
}
}
void
AddEdge
(
const
std
::
string
&
source
,
const
std
::
string
&
target
,
void
AddEdge
(
const
std
::
string
&
source
,
const
std
::
string
&
target
,
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
浏览文件 @
902f19b4
...
@@ -13,3 +13,47 @@
...
@@ -13,3 +13,47 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
// Creates a fresh Scope, stores it in the argument as the "param_scope"
// attribute, and loads the model parameters into it.
//
// The three strings are forwarded unchanged to LoadParams(); its return value
// is not inspected here.
void FluidToIrPass::EnableParamModify(const std::string& model_dir,
                                      const std::string& prog_file,
                                      const std::string& param_file) {
  PADDLE_ENFORCE(argument_);

  // Ownership of the scope moves to the argument; the raw pointer is only
  // borrowed for the load below.
  auto* scope = new framework::Scope;
  argument_->Set("param_scope", scope);

  VLOG(3) << "Loading parameters from " << model_dir;
  LoadParams(scope, model_dir, prog_file, param_file);
}
// Loads model parameters into `scope` using a CPU executor.
//
// If both `prog_file` and `param_file` are non-empty they are used;
// otherwise, if `dir` is non-empty, parameters are loaded from that
// directory. Returns false (after logging an error) when neither source is
// given, true otherwise.
bool FluidToIrPass::LoadParams(framework::Scope* scope, const std::string& dir,
                               const std::string& prog_file,
                               const std::string& param_file) {
  platform::CPUPlace place;
  framework::Executor executor(place);

  // Precondition kept from the original code: the origin program must have
  // been set on the argument before parameters are loaded.
  PADDLE_ENFORCE(argument_->origin_program_desc.get());

  if (!prog_file.empty() && !param_file.empty()) {
    LOG(INFO) << "load single model file from " << prog_file;
    Load(&executor, scope, prog_file, param_file);
  } else if (!dir.empty()) {
    LOG(INFO) << "load from dir " << dir;
    Load(&executor, scope, dir);
  } else {
    LOG(ERROR) << "failed to load parameters";
    return false;
  }
  return true;
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
浏览文件 @
902f19b4
...
@@ -21,12 +21,17 @@ namespace paddle {
...
@@ -21,12 +21,17 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
class
FluidToIrPass
final
:
public
DataFlowGraphPass
{
class
FluidToIrPass
final
:
public
DataFlowGraphPass
{
public:
public:
FluidToIrPass
()
=
default
;
FluidToIrPass
()
=
default
;
bool
Initialize
(
Argument
*
argument
)
override
{
bool
Initialize
(
Argument
*
argument
)
override
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
);
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
);
PADDLE_ENFORCE
(
argument
->
Has
(
kFluidToIrPassesAttr
),
"argument need the attr %s"
,
kFluidToIrPassesAttr
);
argument_
=
argument
;
if
(
argument
->
origin_program_desc
)
{
if
(
argument
->
origin_program_desc
)
{
LOG
(
WARNING
)
<<
"argument's origin_program_desc is already set, might "
LOG
(
WARNING
)
<<
"argument's origin_program_desc is already set, might "
"duplicate called"
;
"duplicate called"
;
...
@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
if
(
!
argument
->
main_dfg
)
{
if
(
!
argument
->
main_dfg
)
{
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
}
}
// Persist the ProgramDesc in graph's attribute. The IR graph just keep the
argument
->
Set
(
"ir_program_desc"
,
new
framework
::
ProgramDesc
(
program
));
// address, will segfault if the original ProgramDesc destroys.
auto
&
ir_program_p
=
argument
->
main_dfg
->
Attr
(
"ir_program_desc"
).
Pointer
();
LOG
(
INFO
)
<<
"Loading parameters"
;
ir_program_p
=
new
framework
::
ProgramDesc
(
program
);
// Load parameters to argument if needed.
if
(
argument
->
fluid_model_dir
||
(
argument
->
fluid_model_program_path
&&
argument
->
fluid_model_param_path
))
{
#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
SAFE_GET
(
fluid_model_dir
);
SAFE_GET
(
fluid_model_program_path
);
SAFE_GET
(
fluid_model_param_path
);
#undef SAFE_GET
EnableParamModify
(
fluid_model_dir
,
fluid_model_program_path
,
fluid_model_param_path
);
}
argument_
=
argument
;
return
true
;
return
true
;
}
}
...
@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
void
Run
(
DataFlowGraph
*
graph
)
override
{
void
Run
(
DataFlowGraph
*
graph
)
override
{
// Call all the IR Passes
// Call all the IR Passes
IRPassManager
ir_passes
(
*
static_cast
<
framework
::
ProgramDesc
*>
(
IRPassManager
ir_passes
(
argument_
->
main_dfg
->
Attr
(
"ir_program_desc"
).
Pointer
()));
argument_
->
Get
<
framework
::
ProgramDesc
>
(
"ir_program_desc"
),
nullptr
);
ir_passes
.
Apply
(
std
::
vector
<
std
::
string
>
(
// Pass the scope from analysis to IR if needed.
{
// Manual update the passes here.
if
(
argument_
->
Has
(
"param_scope"
))
{
"graph_viz_pass"
,
"infer_clean_graph_pass"
,
"graph_viz_pass"
,
// Here the address is passed, attention that IR doesn't own the scope, so
"fc_fuse_pass"
,
"graph_viz_pass"
}));
// the real scope in analysis should live during the IR phase.
ir_passes
.
graph
().
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
)));
}
const
auto
&
ir_passes_to_apply
=
argument_
->
Get
<
std
::
vector
<
std
::
string
>>
(
kFluidToIrPassesAttr
);
ir_passes
.
Apply
(
ir_passes_to_apply
);
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
// PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
}
}
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
std
::
string
repr
()
const
override
{
return
"fluid-to-ir-pass"
;
}
std
::
string
repr
()
const
override
{
return
"fluid-to-ir-pass"
;
}
private:
// Load parameters from a single file or from a directory.
bool
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
private:
private:
Argument
*
argument_
{
nullptr
};
Argument
*
argument_
{
nullptr
};
};
};
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
浏览文件 @
902f19b4
...
@@ -24,6 +24,8 @@ namespace analysis {
...
@@ -24,6 +24,8 @@ namespace analysis {
TEST
(
FluidToIrPass
,
Test
)
{
TEST
(
FluidToIrPass
,
Test
)
{
FluidToIrPass
pass
;
FluidToIrPass
pass
;
Argument
argument
(
FLAGS_inference_model_dir
);
Argument
argument
(
FLAGS_inference_model_dir
);
argument
.
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
({
"infer_clean_graph_pass"
}));
pass
.
Initialize
(
&
argument
);
pass
.
Initialize
(
&
argument
);
pass
.
Run
(
argument
.
main_dfg
.
get
());
pass
.
Run
(
argument
.
main_dfg
.
get
());
}
}
...
@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
...
@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
attention_lstm_fuse_pass
);
USE_PASS
(
fc_lstm_fuse_pass
);
USE_PASS
(
seq_concat_fc_fuse_pass
);
USE_PASS
(
fc_fuse_pass
);
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
902f19b4
...
@@ -14,20 +14,24 @@
...
@@ -14,20 +14,24 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
IRPassManager
::
IRPassManager
(
const
ProgramDesc
&
program
)
{
IRPassManager
::
IRPassManager
(
const
ProgramDesc
&
program
,
framework
::
Scope
*
scope
)
:
program_
(
program
)
{
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
if
(
scope
)
graph_
->
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
scope
));
}
}
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>&
passes
)
{
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
graph_
->
Set
(
"graph_viz_path"
,
new
std
::
string
(
"./1.dot"
));
// Apply all the passes
// Apply all the passes
std
::
string
pre_pass
;
std
::
string
pre_pass
;
for
(
const
std
::
string
&
pass_name
:
passes
)
{
for
(
const
std
::
string
&
pass_name
:
passes
)
{
LOG
(
WARNING
)
<<
"Running IR pass ["
<<
pass_name
<<
"]"
;
LOG
(
WARNING
)
<<
"Running IR pass ["
<<
pass_name
<<
"]"
;
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
if
(
pass_name
==
"graph_viz_pass"
)
{
if
(
pass_name
==
"graph_viz_pass"
)
{
...
...
paddle/fluid/inference/analysis/ir_pass_manager.h
浏览文件 @
902f19b4
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -31,14 +32,15 @@ using framework::ProgramDesc;
...
@@ -31,14 +32,15 @@ using framework::ProgramDesc;
class
IRPassManager
final
{
class
IRPassManager
final
{
public:
public:
IRPassManager
(
const
ProgramDesc
&
program
);
IRPassManager
(
const
ProgramDesc
&
program
,
framework
::
Scope
*
scope
);
void
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
);
void
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
);
framework
::
ir
::
Graph
&
graph
()
const
{
return
*
graph_
;
}
framework
::
ir
::
Graph
&
graph
()
const
{
return
*
graph_
;
}
private:
private:
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph_
;
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph_
;
ProgramDesc
program_
;
};
};
}
// namespace analysis
}
// namespace analysis
...
...
paddle/fluid/inference/analysis/pass_manager.cc
浏览文件 @
902f19b4
...
@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
...
@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void
DfgPassManager
::
RunAll
()
{
void
DfgPassManager
::
RunAll
()
{
PADDLE_ENFORCE
(
argument_
);
PADDLE_ENFORCE
(
argument_
);
LOG
(
INFO
)
<<
"Total "
<<
data_
.
size
()
<<
" passes"
;
LOG
(
INFO
)
<<
"Total "
<<
data_
.
size
()
<<
"
Analysys
passes"
;
for
(
auto
&
pass
:
data_
)
{
for
(
auto
&
pass
:
data_
)
{
LOG
(
WARNING
)
<<
"Running pass ["
<<
pass
->
repr
()
<<
"]"
;
LOG
(
WARNING
)
<<
"Running
Analysis
pass ["
<<
pass
->
repr
()
<<
"]"
;
pass
->
Run
(
argument_
->
main_dfg
.
get
());
pass
->
Run
(
argument_
->
main_dfg
.
get
());
}
}
}
}
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
902f19b4
...
@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
...
@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
endif
(
WITH_TESTING
)
endif
(
WITH_TESTING
)
endfunction
(
inference_api_test
)
endfunction
(
inference_api_test
)
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor
)
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor
)
cc_library
(
analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api
)
cc_test
(
test_paddle_inference_api
cc_test
(
test_paddle_inference_api
SRCS api_tester.cc
SRCS api_tester.cc
...
...
paddle/fluid/inference/api/analysis_predictor.cc
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
using
inference
::
Singleton
;
using
inference
::
analysis
::
Analyzer
;
using
framework
::
proto
::
ProgramDesc
;
/* This predictor is based on the original native predictor with IR and Analysis
 * support. It will optimize IR and Parameters in the runtime.
 * TODO(Superjomn) Replace the Native predictor?
 */
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
}
PADDLE_ENFORCE
(
!
parent_scope
);
if
(
parent_scope
)
{
scope_
=
parent_scope
;
sub_scope_
=
&
(
parent_scope
->
NewScope
());
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
// Initialize the inference program
if
(
!
config_
.
model_dir
.
empty
())
{
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
model_dir
);
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
prog_file
,
config_
.
param_file
);
}
else
{
LOG
(
ERROR
)
<<
"fail to load inference model."
;
return
false
;
}
OptimizeInferenceProgram
();
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
VLOG
(
5
)
<<
"to create variables"
;
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
fetch_target_names_
=
inference_program_
->
GetFetchTargetNames
();
return
true
;
}
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
std
::
vector
<
PaddleTensor
>*
output_data
,
int
batch_size
=
-
1
)
override
{
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
void
OptimizeInferenceProgram
()
{
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Analyze inference_program
Argument
argument
;
if
(
!
config_
.
model_dir
.
empty
())
{
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
{
PADDLE_ENFORCE
(
!
config_
.
param_file
.
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
argument
.
fluid_model_program_path
.
reset
(
new
std
::
string
(
config_
.
prog_file
));
argument
.
fluid_model_param_path
.
reset
(
new
std
::
string
(
config_
.
param_file
));
}
argument
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
CHECK
(
argument
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument
.
Has
(
"param_scope"
));
// Update scope.
scope_
.
reset
(
argument
.
Release
<
framework
::
Scope
>
(
"param_scope"
));
LOG
(
INFO
)
<<
"optimize end =="
;
}
private:
NativeConfig
config_
;
};
// Factory specialization: builds an AnalysisPredictor from a NativeConfig.
//
// When `config.use_gpu` is set, validates `fraction_of_gpu_memory` and
// `device`, then forwards the memory fraction to gflags before the predictor
// is constructed. Returns nullptr when AnalysisPredictor::Init fails.
template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
    const NativeConfig& config) {
  VLOG(3) << "create AnalysisPredictor";
  if (config.use_gpu) {
    // 1. GPU memory
    // NOTE(review): the message advertises the range (0., 1.] but only the
    // lower bound is checked here -- confirm whether an upper bound is wanted.
    PADDLE_ENFORCE_GT(
        config.fraction_of_gpu_memory, 0.f,
        "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
    PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);

    // The flag used to be guarded by `fraction >= 0.0f || fraction <= 0.95f`,
    // which is true for every non-NaN value (and NaN does not satisfy the
    // `> 0` check above), so the flag is always forwarded; the guard is
    // dropped to make that explicit.
    std::vector<std::string> flags;
    // First entry is a placeholder, presumably standing in for argv[0].
    flags.push_back("dummpy");
    std::string flag = "--fraction_of_gpu_memory_to_use=" +
                       std::to_string(config.fraction_of_gpu_memory);
    flags.push_back(flag);
    VLOG(3) << "set flag: " << flag;
    framework::InitGflags(flags);
  }

  // Hold the concrete type so Init can be called without a downcast.
  std::unique_ptr<AnalysisPredictor> predictor(new AnalysisPredictor(config));
  if (!predictor->Init(nullptr)) {
    return nullptr;
  }
  return std::unique_ptr<PaddlePredictor>(predictor.release());
}
}
// namespace paddle
// References to fc_fuse_pass, graph_viz_pass and infer_clean_graph_pass.
// NOTE(review): USE_PASS is defined outside this file; presumably it makes
// sure the named pass registrations are linked into this translation unit --
// confirm against the pass registry header.
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
infer_clean_graph_pass
);
paddle/fluid/inference/api/helper.cc
0 → 100644
浏览文件 @
902f19b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
namespace
paddle
{
namespace
inference
{
// Renders a vector of float vectors: each inner vector is formatted with the
// generic to_string and followed by a newline.
template <>
std::string to_string<std::vector<float>>(
    const std::vector<std::vector<float>> &vec) {
  std::stringstream out;
  for (auto row = vec.begin(); row != vec.end(); ++row) {
    out << to_string(*row);
    out << "\n";
  }
  return out.str();
}
// Renders a three-level nested float vector: every innermost vector is
// formatted with the generic to_string and followed by ";\t"; each
// second-level group is terminated by a newline.
template <>
std::string to_string<std::vector<std::vector<float>>>(
    const std::vector<std::vector<std::vector<float>>> &vec) {
  std::stringstream out;
  for (auto group = vec.begin(); group != vec.end(); ++group) {
    for (auto record = group->begin(); record != group->end(); ++record) {
      out << to_string(*record);
      out << ";\t";
    }
    out << '\n';
  }
  return out.str();
}
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/api/helper.h
浏览文件 @
902f19b4
...
@@ -44,7 +44,8 @@ class Timer {
...
@@ -44,7 +44,8 @@ class Timer {
}
}
};
};
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
pieces
->
clear
();
if
(
str
.
empty
())
{
if
(
str
.
empty
())
{
return
;
return
;
...
@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
...
@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
pieces
->
push_back
(
str
.
substr
(
pos
));
pieces
->
push_back
(
str
.
substr
(
pos
));
}
}
}
}
void
split_to_float
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
float
>
*
fs
)
{
static
void
split_to_float
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
float
>
*
fs
)
{
std
::
vector
<
std
::
string
>
pieces
;
std
::
vector
<
std
::
string
>
pieces
;
split
(
str
,
sep
,
&
pieces
);
split
(
str
,
sep
,
&
pieces
);
std
::
transform
(
pieces
.
begin
(),
pieces
.
end
(),
std
::
back_inserter
(
*
fs
),
std
::
transform
(
pieces
.
begin
(),
pieces
.
end
(),
std
::
back_inserter
(
*
fs
),
...
@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
...
@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
}
}
template
<
>
template
<
>
std
::
string
to_string
<
std
::
vector
<
float
>>
(
std
::
string
to_string
<
std
::
vector
<
float
>>
(
const
std
::
vector
<
std
::
vector
<
float
>>
&
vec
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
&
vec
);
std
::
stringstream
ss
;
for
(
const
auto
&
piece
:
vec
)
{
ss
<<
to_string
(
piece
)
<<
"
\n
"
;
}
return
ss
.
str
();
}
template
<
>
template
<
>
std
::
string
to_string
<
std
::
vector
<
std
::
vector
<
float
>>>
(
std
::
string
to_string
<
std
::
vector
<
std
::
vector
<
float
>>>
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
&
vec
)
{
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
&
vec
);
std
::
stringstream
ss
;
for
(
const
auto
&
line
:
vec
)
{
for
(
const
auto
&
rcd
:
line
)
{
ss
<<
to_string
(
rcd
)
<<
";
\t
"
;
}
ss
<<
'\n'
;
}
return
ss
.
str
();
}
// clang-format off
// clang-format off
void
TensorAssignData
(
PaddleTensor
*
tensor
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
data
)
{
static
void
TensorAssignData
(
PaddleTensor
*
tensor
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
data
)
{
// Assign buffer
// Assign buffer
int
dim
=
std
::
accumulate
(
tensor
->
shape
.
begin
(),
tensor
->
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
int
dim
=
std
::
accumulate
(
tensor
->
shape
.
begin
(),
tensor
->
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
tensor
->
data
.
Resize
(
sizeof
(
float
)
*
dim
);
tensor
->
data
.
Resize
(
sizeof
(
float
)
*
dim
);
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
902f19b4
...
@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
...
@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
kNative
=
0
,
// Use the native Fluid facility.
kNative
=
0
,
// Use the native Fluid facility.
kAnakin
,
// Use Anakin for inference.
kAnakin
,
// Use Anakin for inference.
kAutoMixedTensorRT
,
// Automatically mix Fluid with TensorRT.
kAutoMixedTensorRT
,
// Automatically mix Fluid with TensorRT.
kAnalysis
// TODO(Superjomn) support following engines latter.
// TODO(Superjomn) support following engines latter.
// kTensorRT, // Use TensorRT for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
...
...
paddle/fluid/inference/io.cc
浏览文件 @
902f19b4
...
@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
...
@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
return
main_program
;
return
main_program
;
}
}
// Writes the variables named in `vars` from `scope` to `dirname + "/param"`
// by building a one-op program (`save_combine`) and running it on a CPU
// executor.
//
// `predicate` is accepted but not read by this implementation.
void SaveVars(const framework::Scope& scope,
              const std::vector<std::string>& vars, const std::string& dirname,
              bool predicate) {
  // Build the single-op program in block 0.
  framework::ProgramDesc save_program;
  auto* save_op = save_program.MutableBlock(0)->AppendOp();
  save_op->SetType("save_combine");
  save_op->SetInput("X", vars);
  save_op->SetAttr("file_path", dirname + "/param");
  save_op->CheckAttrs();

  // Executor::Run takes a mutable scope pointer, hence the const_cast.
  auto* mutable_scope = const_cast<framework::Scope*>(&scope);
  platform::CPUPlace cpu_place;
  framework::Executor executor(cpu_place);
  executor.Run(save_program, mutable_scope, 0, true, true);
}
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/io.h
浏览文件 @
902f19b4
...
@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
...
@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
const
std
::
string
&
prog_filename
,
const
std
::
string
&
prog_filename
,
const
std
::
string
&
param_filename
);
const
std
::
string
&
param_filename
);
// Save the variables named in `vars` from `scope` to disk.
// The definition in io.cc runs a single `save_combine` op on a CPU executor
// and writes to `dirname + "/param"`.
// `predicate` defaults to true; the io.cc definition does not read it.
void
SaveVars
(
const
framework
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
vars
,
const
std
::
string
&
dirname
,
bool
predicate
=
true
);
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/attention_lstm_op.cc
浏览文件 @
902f19b4
...
@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
const
int
D
=
w_dims
[
1
]
/
4
;
const
int
D
=
w_dims
[
1
]
/
4
;
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"Input(LSTMWeight)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"Input(LSTMWeight)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
D
+
M
,
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
D
+
M
,
"LSTMWeight dims should be (%d + %d) * %d."
,
D
+
M
,
4
*
D
);
"LSTMWeight dims should be (%d + %d) * %d."
,
D
,
M
,
4
*
D
);
auto
b_dims
=
ctx
->
GetInputDim
(
"LSTMBias"
);
auto
b_dims
=
ctx
->
GetInputDim
(
"LSTMBias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"Input(LSTMBias)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"Input(LSTMBias)'s rank must be 2."
);
...
...
paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
浏览文件 @
902f19b4
...
@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
...
@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
"FC height should be sum of all inputs width."
);
"FC height should be sum of all inputs width."
);
if
(
ctx
->
HasInput
(
"FCBias"
))
{
if
(
ctx
->
HasInput
(
"FCBias"
))
{
auto
b_dims
=
ctx
->
GetInputDim
(
"FCBias"
);
auto
b_dims
=
ctx
->
GetInputDim
(
"FCBias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"Input(FCBias)'s rank must be 2."
);
PADDLE_ENFORCE
(
b_dims
.
size
()
==
1
||
b_dims
.
size
()
==
2
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"FCBias shapes must be 1 * %d."
,
D
);
"b_dims should be 1 or 2, get %d"
,
b_dims
.
size
());
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
D
,
"FCBias shapes must be 1 * %d."
,
D
);
if
(
b_dims
.
size
()
==
1
)
{
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
D
,
"FCBias shapes must be %d."
,
D
);
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"FCBias shapes must be 1x%d."
,
D
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
D
,
"FCBias shapes must be 1x%d."
,
D
);
}
}
}
ctx
->
SetOutputDim
(
"Out"
,
{
ins_dims
[
0
][
0
],
D
});
ctx
->
SetOutputDim
(
"Out"
,
{
ins_dims
[
0
][
0
],
D
});
...
...
paddle/fluid/platform/init.cc
浏览文件 @
902f19b4
...
@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
...
@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
}
catch
(
const
std
::
exception
&
exp
)
{
}
catch
(
const
std
::
exception
&
exp
)
{
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
}
}
#else
LOG
(
WARNING
)
<<
"'CUDA' is not supported, Please re-compile with WITH_GPU option"
;
#endif
#endif
InitDevices
(
init_p2p
,
devices
);
InitDevices
(
init_p2p
,
devices
);
}
}
...
@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
...
@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
}
catch
(
const
std
::
exception
&
exp
)
{
}
catch
(
const
std
::
exception
&
exp
)
{
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
LOG
(
WARNING
)
<<
"Compiled with WITH_GPU, but no GPU found in runtime."
;
}
}
#else
LOG
(
WARNING
)
<<
"'CUDA' is not supported, Please re-compile with WITH_GPU option"
;
#endif
#endif
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录