Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit 902f19b4 (unverified)
Authored on Aug 29, 2018 by Yan Chunwei; committed via GitHub on Aug 29, 2018
fea/fuse attention lstm simplify.with fusion lstm.with sequnce expand (#13006)
Parent: 55f240ba

Showing 40 changed files with 1507 additions and 211 deletions (+1507 −211)
paddle/fluid/framework/ir/CMakeLists.txt                          +7    -5
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc             +273  -0
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h              +14   -7
paddle/fluid/framework/ir/fc_fuse_pass.cc                         +6    -8
paddle/fluid/framework/ir/fc_fuse_pass.h                          +1    -1
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc                    +126  -0
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h                     +33   -0
paddle/fluid/framework/ir/fuse_pass_base.h                        +44   -0
paddle/fluid/framework/ir/graph.h                                 +15   -4
paddle/fluid/framework/ir/graph_helper.cc                         +1    -1
paddle/fluid/framework/ir/graph_pattern_detector.cc               +61   -16
paddle/fluid/framework/ir/graph_pattern_detector.h                +84   -15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc        +5    -5
paddle/fluid/framework/ir/graph_viz_pass.cc                       +54   -28
paddle/fluid/framework/ir/graph_viz_pass.h                        +9    -0
paddle/fluid/framework/ir/node.h                                  +15   -6
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc              +256  -0
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h               +33   -0
paddle/fluid/inference/analysis/CMakeLists.txt                    +16   -8
paddle/fluid/inference/analysis/analyzer.cc                       +13   -0
paddle/fluid/inference/analysis/analyzer_tester.cc                +24   -47
paddle/fluid/inference/analysis/argument.h                        +41   -0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc  +5    -0
paddle/fluid/inference/analysis/dot.h                             +12   -6
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc               +44   -0
paddle/fluid/inference/analysis/fluid_to_ir_pass.h                +42   -12
paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc        +6    -1
paddle/fluid/inference/analysis/ir_pass_manager.cc                +8    -4
paddle/fluid/inference/analysis/ir_pass_manager.h                 +5    -3
paddle/fluid/inference/analysis/pass_manager.cc                   +2    -2
paddle/fluid/inference/api/CMakeLists.txt                         +3    -2
paddle/fluid/inference/api/analysis_predictor.cc                  +165  -0
paddle/fluid/inference/api/helper.cc                              +44   -0
paddle/fluid/inference/api/helper.h                               +9    -20
paddle/fluid/inference/api/paddle_inference_api.h                 +1    -0
paddle/fluid/inference/io.cc                                      +16   -0
paddle/fluid/inference/io.h                                       +5    -0
paddle/fluid/operators/attention_lstm_op.cc                       +1    -1
paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc           +8    -3
paddle/fluid/platform/init.cc                                     +0    -6
paddle/fluid/framework/ir/CMakeLists.txt
@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
-cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
-cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
+cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
 cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
+cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)

 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
-cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
-cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
+cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
+cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"

namespace paddle {
namespace framework {
namespace ir {

struct Param {
  std::string X = "concat_0.tmp_0";
  std::string C0 = "cell_init";
  std::string H0 = "hidden_init";
  std::string AttentionWeight = "attention_fc.w_0";
  std::string AttentionBias = "attention_fc.b_0";
  std::string AttentionScalar = "attention_output.w_0";
  std::string AttentionScalarBias = "attention_output.b_0";
  std::string LSTMWeight = "attention_w.new";
  std::string LSTMBias = "attention_b.new";
  std::string Hidden = "array_to_lod_tensor_0.tmp_0";
  std::string Cell = "at.cell.new";
  std::string AttentionedX = "at.x.new";
  std::string AttentionFCOut = "at.fc.new";
  std::string LSTMX = "at.lstmx.new";
  std::string LSTMOUT = "at.lstmout.new";
};

void PrepareParameters(Graph* graph, const Param& param);

void FindWhileOp(Graph* graph) {
  GraphPatternDetector gpd;
  std::unordered_set<int> fused_external_ops(
      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48, 57, 55, 56,
       52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});

  gpd.mutable_pattern()->NewNode(
      [&](Node* n) { return fused_external_ops.count(n->id()); }, "while");

  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
  }
  auto& marked_nodes =
      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* g) {
    auto* while_pat_node = gpd.pattern().RetriveNode("while");
    auto* while_node = subgraph.at(while_pat_node);
    marked_nodes.insert(while_node);
  };
  gpd(graph, handle);

  Param param;
  // Add AttentionLSTM node
  OpDesc op_desc;
  op_desc.SetType("attention_lstm");

#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
  OP_SET_IN(X);
  OP_SET_IN(C0);
  OP_SET_IN(H0);
  OP_SET_IN(AttentionWeight);
  OP_SET_IN(AttentionBias);
  OP_SET_IN(AttentionScalar);
  OP_SET_IN(AttentionScalarBias);
  OP_SET_IN(LSTMWeight);
  OP_SET_IN(LSTMBias);

  OP_SET_OUT(Hidden);
  OP_SET_OUT(Cell);
  OP_SET_OUT(AttentionedX);
  OP_SET_OUT(AttentionFCOut);
  OP_SET_OUT(LSTMX);
  OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT

  auto* X = graph->RetriveNode(34);
  auto* LSTMOUT = graph->RetriveNode(81);
  auto* cell_init = graph->RetriveNode(6);
  auto* hidden_init = graph->RetriveNode(8);

#define LINK_TO(node0, node1)      \
  node0->outputs.push_back(node1); \
  node1->inputs.push_back(node0);

  auto* lstm_op = graph->CreateOpNode(&op_desc);
  PrepareParameters(graph, param);

  LINK_TO(X, lstm_op);
  LINK_TO(cell_init, lstm_op);
  LINK_TO(hidden_init, lstm_op);
  LINK_TO(lstm_op, LSTMOUT);

  GraphSafeRemoveNodes(graph, marked_nodes);
}

#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
  CHECK_P1(x0);          \
  CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
  CHECK_P2(x0, x1);          \
  CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
  CHECK_P3(x0, x1, x2);          \
  CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
  CHECK_P4(x0, x1, x2, x3);          \
  CHECK_P1(x4);

void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1,
                       const LoDTensor& W_cell_w0, const LoDTensor& W_cell_w1,
                       LoDTensor* out);

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out);

void PrepareParameters(Graph* graph, const Param& param) {
  // Check parameters
  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
  auto* scope = graph->Get<Scope*>(kParamScopeAttr);

  // Create new parameters.
  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
  scope->Var(param.Cell)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();

#define GATE_W(name__)                                               \
  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
  VLOG(4) << #name__ "_w0"                                           \
          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_w1"                                           \
          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_b0"                                           \
          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();       \
  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();       \
  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

  GATE_W(forget);
  GATE_W(input);
  GATE_W(output);
  GATE_W(c);
#undef GATE_W

  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
  auto* attention_output_w = scope->FindVar("attention_output.w_0");
  auto* attention_output_b = scope->FindVar("attention_output.b_0");
  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
           attention_output_b);

  auto* lstm_weight = scope->Var(param.LSTMWeight);
  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
  auto* lstm_bias = scope->Var(param.LSTMBias);
  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

  // reshape attention_bias
  auto* attention_bias_t =
      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

  auto* attention_scalar_bias_t =
      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
  attention_scalar_bias_t->Resize(
      make_ddim({1, attention_scalar_bias_t->dims()[0]}));

  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
                    lstm_weight_t);
  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
                  lstm_bias_t);
}

// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1,
                       const LoDTensor& W_cell_w0, const LoDTensor& W_cell_w1,
                       LoDTensor* out) {
  int D = W_forget_w0.dims()[0];
  int M = W_forget_w1.dims()[0];
  out->Resize(make_ddim({D + M, 4 * D}));
  VLOG(3) << "LSTMWeight resized to " << out->dims();

  float* out_data = out->mutable_data<float>(platform::CPUPlace());
  std::array<const float*, 4> tensors(
      {W_forget_w0.data<float>(), W_input_w0.data<float>(),
       W_output_w0.data<float>(), W_cell_w0.data<float>()});
  std::array<const float*, 4> tensors1(
      {W_forget_w1.data<float>(), W_input_w1.data<float>(),
       W_output_w1.data<float>(), W_cell_w1.data<float>()});

  for (int row = 0; row < D; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * row + D * col;
      const float* src = tensors[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }

  for (int row = 0; row < M; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * (D + row) + D * col;
      const float* src = tensors1[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }
}

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out) {
  std::array<const float*, 4> tensors(
      {B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
       B_cell.data<float>()});

  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
  int D = B_forget.dims()[0];
  out->Resize(make_ddim({1, 4 * D}));
  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
  for (size_t i = 0; i < tensors.size(); i++) {
    memcpy(out_data + D * i, tensors[i], D * sizeof(float));
  }
}

// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PDPattern external_pattern, subblock_pattern;

  FindWhileOp(graph.get());
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(attention_lstm_fuse_pass,
              paddle::framework::ir::AttentionLSTMFusePass);
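
Note on the weight layout: PrepareLSTMWeight packs the four gate weights (forget, input, output, cell), each split into a [D, D] part (w0) and an [M, D] part (w1), into one [D + M, 4 * D] matrix whose four column blocks are the gates. A minimal standalone sketch of that packing, assuming row-major float buffers; PackGates and its arguments are illustrative helpers, not part of this commit:

#include <cstring>
#include <vector>

// Pack four [rows, D] gate matrices side by side into one [rows, 4*D] matrix,
// mirroring the inner loops of PrepareLSTMWeight (row-major storage assumed).
std::vector<float> PackGates(const std::vector<const float*>& gates,  // 4 gates
                             int rows, int D) {
  std::vector<float> out(static_cast<size_t>(rows) * 4 * D);
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < 4; ++col) {
      // Destination: row `row`, column block `col` of width D.
      float* dst = out.data() + 4 * D * row + D * col;
      const float* src = gates[col] + D * row;
      std::memcpy(dst, src, D * sizeof(float));
    }
  }
  return out;
}

In the pass itself, this packing is done twice: once for the w0 parts (rows = D, written at row offset 0) and once for the w1 parts (rows = M, written starting at row D), which yields the combined [D + M, 4 * D] LSTMWeight tensor.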
paddle/fluid/inference/analysis/dot.cc → paddle/fluid/framework/ir/attention_lstm_fuse_pass.h (renamed)

 // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,12 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/inference/analysis/dot.h"
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"

 namespace paddle {
-namespace inference {
-namespace analysis {
-size_t Dot::counter = 0;
-}  // namespace analysis
-}  // namespace inference
+namespace framework {
+namespace ir {
+
+class AttentionLSTMFusePass : public FusePassBase {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/fc_fuse_pass.cc

@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
       },
       "elementwise_add_out");

-  pattern->AddEdge(mul_parameter_var, mul_op);
-  pattern->AddEdge(mul_tmp_input_var, mul_op);
-  pattern->AddEdge(mul_op, mul_out_var);
-  pattern->AddEdge(mul_out_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
+  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
+      .LinksTo({mul_out_var});
+  elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
+      .LinksTo({elementwise_add_out_var});
 }

 // Replace the node `from` in the links to `to`
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   std::unordered_set<Node*> nodes2delete;

-  GraphPatternDetecter gpd;
+  GraphPatternDetector gpd;
   BuildFCPattern(gpd.mutable_pattern());

 #define GET_NODE(id) \
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

-  auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
     // Currently, there is no FC op available, so I will just simulate the
paddle/fluid/framework/ir/fc_fuse_pass.h

@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  std::unordered_set<int> fused_ops({// first lstm
                                     13, 15, 16,
                                     // second lstm
                                     23, 25, 26});

  pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
                   "any_node");

  std::unordered_set<Node*> marked_nodes;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
    marked_nodes.insert(id);
  };
  gpd(graph.get(), handler);

  // Create New OpDesc
  auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
                          int bias, int hidden, int cell, int xx) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
    GET_NODE(input);
    GET_NODE(weight_x);
    GET_NODE(weight_h);
    GET_NODE(bias);
    GET_NODE(hidden);
    GET_NODE(cell);
    GET_NODE(xx);
    GET_NODE(lstm);

    OpDesc op_desc;
    op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
    SET_IN(X, input);
    SET_IN(WeightX, weight_x);
    SET_IN(WeightH, weight_h);
    SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN

    LOG(INFO) << "hidden_n: " << hidden_n->Name();
    LOG(INFO) << "cell: " << cell_n->Name();
    LOG(INFO) << "xx: " << xx_n->Name();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden_n->Name()});
    op_desc.SetOutput("Cell", {cell_n->Name()});
    op_desc.SetOutput("XX", {xx_n->Name()});
    op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
    op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
    op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", false);
    auto* op = graph->CreateOpNode(&op_desc);

#define LINK_TO(a, b)      \
  a->outputs.push_back(b); \
  b->inputs.push_back(a);
    LINK_TO(input_n, op);
    LINK_TO(weight_x_n, op);
    LINK_TO(weight_h_n, op);
    LINK_TO(bias_n, op);
    LINK_TO(op, hidden_n);
#undef LINK_TO
    return op;
  };

  lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
  lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);

  // remove all the nodes
  for (auto* node : marked_nodes) {
    graph->RemoveNode(const_cast<Node*>(node));
  }

  for (auto* node : graph->Nodes()) {
    for (auto it = node->inputs.begin(); it != node->inputs.end();) {
      if (marked_nodes.count(*it)) {
        it = const_cast<Node*>(node)->inputs.erase(it);
      } else
        it++;
    }
    for (auto it = node->outputs.begin(); it != node->outputs.end();) {
      if (marked_nodes.count(*it)) {
        it = const_cast<Node*>(node)->outputs.erase(it);
      } else
        it++;
    }
  }

  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class FCLstmFusePass : public Pass {
 public:
  virtual ~FCLstmFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/fuse_pass_base.h (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace ir {

static const char kParamScopeAttr[] = "param_scope";

class FusePassBase : public Pass {
 public:
  void Init(Graph* graph) const { graph_ = graph; }

  Scope* param_scope() const {
    PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
    return graph_->Get<framework::Scope*>(kParamScopeAttr);
  }

  virtual ~FusePassBase() {}

 protected:
  mutable Graph* graph_;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
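
fuse_pass_base.h is the small contract the new fuse passes share: Init() records the graph on the pass, and param_scope() fetches the parameter Scope that the analysis driver stores on the graph under kParamScopeAttr. A minimal sketch of how a derived pass might use it; MyFusePass and the variable name "some_weight" are hypothetical, not part of this commit:

// Illustrative only: a hypothetical pass built on FusePassBase.
class MyFusePass : public FusePassBase {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    Init(graph.get());  // remember the graph so param_scope() works
    // Look up a persistable parameter by name in the attached scope.
    auto* var = param_scope()->FindVar("some_weight");  // hypothetical name
    if (var != nullptr) {
      auto& tensor = var->Get<LoDTensor>();
      VLOG(3) << "weight dims: " << tensor.dims();
    }
    return graph;
  }
};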
paddle/fluid/framework/ir/graph.h

@@ -99,13 +99,13 @@ class Graph {
   // Create a normal variable with non-null VarDesc.
   ir::Node* CreateVarNode(VarDesc* var_desc) {
     PADDLE_ENFORCE(var_desc);
-    return AddNode(new ir::Node(var_desc));
+    return AddNode(new ir::Node(var_desc, node_count_++));
   }

   // Create a normal runnable operator with OpDesc.
   ir::Node* CreateOpNode(OpDesc* op_desc) {
     PADDLE_ENFORCE(op_desc);
-    return AddNode(new ir::Node(op_desc));
+    return AddNode(new ir::Node(op_desc, node_count_++));
   }

   // Create a control dependency var that connects 2 operations. The
@@ -115,13 +115,14 @@ class Graph {
     // TODO(panyx0718): control var name should be really unique.
     const std::string name = string::Sprintf(
         "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
-    return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
+    return AddNode(
+        new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
   }

   // A more free style way of creating a graph node. Mostly use for test
   // or "copy" from another node. Avoid using it if possible.
   ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
-    return AddNode(new ir::Node(name, type));
+    return AddNode(new ir::Node(name, type, node_count_++));
   }

   // Clear all node information of the graph and return the ownership of the
@@ -142,12 +143,20 @@ class Graph {
     nodes_.erase(node);
   }

+  Node* RetriveNode(int id) {
+    auto it = id2node_.find(id);
+    if (it != id2node_.end()) return it->second;
+    return nullptr;
+  }
+
  private:
   // This method takes ownership of `node`.
   ir::Node* AddNode(ir::Node* node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
     nodes_[node].reset(node);
     node_set_.insert(node);
+    PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
+    id2node_[node->id()] = node;
     return node;
   }
@@ -157,6 +166,8 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
   std::map<ir::Node*, std::unique_ptr<ir::Node>> nodes_;
   std::unordered_set<ir::Node*> node_set_;
+  std::map<int, Node*> id2node_;
+  int node_count_{0};
 };

 bool IsControlDepVar(const ir::Node& var);
paddle/fluid/framework/ir/graph_helper.cc

@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
         PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
-        adj_list[n].insert(adj_n);
         VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << "  via " << var->Name() << reinterpret_cast<void *>(var);
+        adj_list[n].insert(adj_n);
       }
     }
paddle/fluid/framework/ir/graph_pattern_detecter.cc → paddle/fluid/framework/ir/graph_pattern_detector.cc (renamed)

@@ -17,7 +17,7 @@
 #include <vector>

 #include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
         name);
   }

-  nodes_.emplace_back(new PDNode(std::move(teller), name));
+  nodes_.emplace_back(new PDNode(std::move(teller), this, name));
   auto* cur = nodes_.back().get();
   node_map_[name] = cur;
   return cur;
@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
   edges_.emplace_back(a, b);
 }

-void GraphPatternDetecter::operator()(Graph* graph,
-                                      GraphPatternDetecter::handle_t handler) {
+void GraphPatternDetector::operator()(Graph* graph,
+                                      GraphPatternDetector::handle_t handler) {
   if (!MarkPDNodesInGraph(*graph)) return;
   auto subgraphs = DetectPatterns();
   UniquePatterns(&subgraphs);
   RemoveOverlappedMatch(&subgraphs);

+  LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
+  int id = 0;
   for (auto& g : subgraphs) {
+    LOG(INFO) << "optimizing #" << id++ << " subgraph";
     handler(g, graph);
   }
 }

-bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
+bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
   VLOG(4) << "mark pdnodes in graph";
   if (graph.Nodes().empty()) return false;
@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
   return false;
 }

-std::vector<GraphPatternDetecter::subgraph_t>
-GraphPatternDetecter::DetectPatterns() {
+std::vector<GraphPatternDetector::subgraph_t>
+GraphPatternDetector::DetectPatterns() {
   // Init empty subgraphs.
-  std::vector<GraphPatternDetecter::subgraph_t> result;
+  std::vector<GraphPatternDetector::subgraph_t> result;
   std::vector<HitGroup> init_groups;
-  PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
-  auto* first_pnode = pattern_.edges().front().first;
+  std::array<std::vector<HitGroup>, 2> bi_records;
+  // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
+  auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
+                                               : pattern_.edges().front().first;
   if (!pdnodes2nodes_.count(first_pnode)) return result;
   for (auto* node : pdnodes2nodes_[first_pnode]) {
     HitGroup group;
@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
   }

   int step = 0;
-  std::array<std::vector<HitGroup>, 2> bi_records;
   bi_records[0] = std::move(init_groups);

   // Extend a PDNode to subgraphs by deducing the connection relations defined
@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
     auto& pre_groups = bi_records[step % 2];
     auto& cur_groups = bi_records[1 - (step++ % 2)];
     cur_groups.clear();
+    if (pre_groups.empty()) break;
     // source -> target
     for (Node* source : pdnodes2nodes_[edge.first]) {
       for (Node* target : pdnodes2nodes_[edge.second]) {
@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
   }

   for (auto& group : bi_records[step % 2]) {
-    GraphPatternDetecter::subgraph_t subgraph;
+    GraphPatternDetector::subgraph_t subgraph;
     for (auto& role : group.roles) {
       subgraph.emplace(role.first, role.second);
     }
@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
   return result;
 }

-void GraphPatternDetecter::UniquePatterns(
-    std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) {
+void GraphPatternDetector::UniquePatterns(
+    std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
   if (subgraphs->empty()) return;
-  std::vector<GraphPatternDetecter::subgraph_t> result;
+  std::vector<GraphPatternDetector::subgraph_t> result;
   std::unordered_set<size_t> set;
   for (auto& g : *subgraphs) {
@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
   *subgraphs = result;
 }

-void GraphPatternDetecter::RemoveOverlappedMatch(
+void GraphPatternDetector::RemoveOverlappedMatch(
     std::vector<subgraph_t>* subgraphs) {
   std::vector<subgraph_t> result;
   std::unordered_set<Node*> node_set;
@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
   *subgraphs = result;
 }

+std::string PDPattern::DotString() const {
+  using inference::analysis::Dot;
+  Dot dot;
+  int id = 0;
+  // Create Nodes
+  std::unordered_map<PDNode*, std::string> node2dot;
+  for (const auto& node : nodes()) {
+    std::string node_id = "Node" + std::to_string(id++);
+    dot.AddNode(node_id, {}, node->name());
+    node2dot[node.get()] = node_id;
+  }
+  // Create Edges
+  for (const auto& edge : edges()) {
+    if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
+      LOG(ERROR) << "no node " << edge.first << " " << edge.second;
+      continue;
+    }
+    auto& src = node2dot.at(edge.first);
+    auto& trg = node2dot.at(edge.second);
+    dot.AddEdge(src, trg, {});
+  }
+  return dot.Build();
+}
+
+PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
+  // extend outlinks.
+  for (PDNode* x : others) {
+    pattern_->AddEdge(this, x);
+  }
+  return *this;
+}
+
+PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
+  // extend outlinks.
+  for (PDNode* x : others) {
+    pattern_->AddEdge(x, this);
+  }
+  return *this;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/graph_pattern_detecter.h → paddle/fluid/framework/ir/graph_pattern_detector.h (renamed)

@@ -21,12 +21,14 @@
 #include <numeric>

 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/inference/analysis/dot.h"

 namespace paddle {
 namespace framework {
 namespace ir {
+class PDPattern;

-// Some basic torminolygies:
+// Some basic terminologies:
 //   - PDPattern: a pattern defined as a data flow graph.
 //   - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
 //     that meets some conditions defined in `PDNode.teller`.
@@ -36,30 +38,43 @@ namespace ir {
 struct PDNode {
   // tell whether an ir::Node* is a candidation for a PDNode.
   using teller_t = std::function<bool(Node*)>;
+  enum class Type { kOp, kVar };

-  PDNode(teller_t&& teller, const std::string& name = "")
-      : teller_(teller), name_(name) {
-    PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
-  }
-
-  PDNode(PDNode&& other) = default;
-
-  std::vector<PDNode*> inlinks;
-  std::vector<PDNode*> outlinks;
+  // this link to others
+  PDNode& LinksTo(const std::vector<PDNode*>& others);
+  PDNode& LinksFrom(const std::vector<PDNode*>& others);

   bool Tell(Node* node) const {
     PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
     return teller_(node);
   }

+  bool IsOp() const { return type_ == Type::kOp; }
+  bool IsVar() const { return type_ == Type::kVar; }
+
   const std::string& name() const { return name_; }

   PDNode(const PDNode&) = delete;
   PDNode& operator=(const PDNode&) = delete;

  private:
+  PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
+         Type type = Type::kVar)
+      : teller_(std::move(teller)),
+        pattern_(pattern),
+        name_(name),
+        type_(type) {
+    PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
+  }
+  PDNode(PDNode&& other) = default;
+
+  friend class PDPattern;
+
   teller_t teller_;
+  PDPattern* pattern_;
   std::string name_;
+  Type type_;
 };

 /*
@@ -102,6 +117,8 @@ class PDPattern {
   const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
   const std::vector<edge_t>& edges() const { return edges_; }

+  std::string DotString() const;
+
  private:
 #ifdef PADDLE_WITH_TESTING
   FRIEND_TEST(PDPattern, AddEdge);
@@ -117,7 +134,7 @@ class PDPattern {
 };

 /*
- * GraphPatternDetecter helps to detect the specific patterns in the graph.
+ * GraphPatternDetector helps to detect the specific patterns in the graph.
  * Input a pattern, output a list of the matched subgraphs/nodes.
  * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
  *
@@ -129,7 +146,7 @@ class PDPattern {
 *
 * Usage:
 *    // Create a detector
- *    GraphPatternDetecter detector;
+ *    GraphPatternDetector detector;
 *    // Define the detector's pattern, by adding PDNode and define the edges.
 *    auto* node0 = detector.mutable_pattern().AddNode(...)
 *    auto* node1 = detector.mutable_pattern().AddNode(...)
@@ -138,11 +155,11 @@ class PDPattern {
 *    detector.mutable_pattern().AddEdge(node0, node1);
 *    // Create an handler, to define the behavior of treating the filtered
 *    // subgraphs that comply with the patterns.
- *    GraphPatternDetecter::handle_t handler = some labmda
+ *    GraphPatternDetector::handle_t handler = some labmda
 *    // Execute the detector.
 *    detector(&graph, handler);
 */
-class GraphPatternDetecter {
+class GraphPatternDetector {
  public:
   using subgraph_t = std::unordered_map<PDNode*, Node*>;
@@ -177,10 +194,62 @@ class GraphPatternDetecter {
   using hit_rcd_t =
       std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
   PDPattern pattern_;
   std::vector<hit_rcd_t> marked_records_;
   std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
 };

+// some helper methods.
+
+// Op's input.
+static bool VarLinksToOp(Node* node, const std::string& op_type) {
+  for (auto* out : node->outputs) {
+    if (out->IsOp() && out->Op()->Type() == op_type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Op's output.
+static bool VarLinksFromOp(Node* node, const std::string& op_type) {
+  for (auto* out : node->inputs) {
+    if (out->IsOp() && out->Op()->Type() == op_type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Check whether a var node is a op node's nth input.
+static bool IsNthInput(Node* var, Node* op, const std::string& argument,
+                       size_t nth) {
+  PADDLE_ENFORCE(var->IsVar());
+  PADDLE_ENFORCE(op->IsOp());
+  if (op->inputs.size() <= nth) return false;
+  return var->Name() == op->Op()->Input(argument)[nth];
+}
+
+static void GraphSafeRemoveNodes(
+    Graph* graph, const std::unordered_set<const Node*>& nodes) {
+  for (auto* node : nodes) {
+    graph->RemoveNode(const_cast<Node*>(node));
+  }
+
+  for (auto* node : graph->Nodes()) {
+    for (auto it = node->inputs.begin(); it != node->inputs.end();) {
+      if (nodes.count(*it)) {
+        it = const_cast<Node*>(node)->inputs.erase(it);
+      } else
+        it++;
+    }
+    for (auto it = node->outputs.begin(); it != node->outputs.end();) {
+      if (nodes.count(*it)) {
+        it = const_cast<Node*>(node)->outputs.erase(it);
+      } else
+        it++;
+    }
+  }
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
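
The usage block in the header comment above is schematic. A slightly fuller sketch against the detector API reconstructed in this commit; the "mul" teller and the logging handler are illustrative, and note that after this commit DetectPatterns accepts a single-node pattern with no edges (the passes above rely on that):

// Sketch: match every `mul` op node in a graph and report it.
void FindMulOps(paddle::framework::ir::Graph* graph) {
  using namespace paddle::framework::ir;
  GraphPatternDetector detector;
  // A PDNode matches any ir::Node for which its teller returns true.
  auto* mul_pat = detector.mutable_pattern()->NewNode(
      [](Node* x) { return x && x->IsOp() && x->Op()->Type() == "mul"; },
      "mul_op");
  GraphPatternDetector::handle_t handler =
      [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
        Node* matched = subgraph.at(mul_pat);  // pattern node -> graph node
        LOG(INFO) << "matched mul op: " << matched->Name();
      };
  detector(graph, handler);  // runs handler once per matched subgraph
}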
paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc → paddle/fluid/framework/ir/graph_pattern_detector_tester.cc (renamed)

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

 #include <gtest/gtest.h>
@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
 }

 TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
-  GraphPatternDetecter x;
+  GraphPatternDetector x;
   // mark o2, o3, v2

   // The pattern is a graph:
@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   Graph graph(program);
   BuildGraph(&graph);

-  GraphPatternDetecter x;
+  GraphPatternDetector x;

   // The pattern is a graph:
   //   op -> var
@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   x.mutable_pattern()->AddEdge(any_var, any_op1);

   int count = 0;
-  GraphPatternDetecter::handle_t handle = [&](
-      const GraphPatternDetecter::subgraph_t& s, Graph* g) {
+  GraphPatternDetector::handle_t handle = [&](
+      const GraphPatternDetector::subgraph_t& s, Graph* g) {
     LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
               << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
     count++;
paddle/fluid/framework/ir/graph_viz_pass.cc

@@ -16,11 +16,13 @@ limitations under the License. */
 #include <unordered_set>

 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/inference/analysis/dot.h"

 namespace paddle {
 namespace framework {
 namespace ir {
 static const char kGraphVizPath[] = "graph_viz_path";
+using inference::analysis::Dot;

 std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
   PADDLE_ENFORCE(fout->good());
   std::ostream& sout = *fout;

-  size_t var_id = 0;
-  std::unordered_map<const ir::Node*, size_t> vars;
-
-  sout << "digraph G {\n";
-
-  for (const ir::Node* n : graph->Nodes()) {
-    if (n->NodeType() != ir::Node::Type::kVariable) continue;
-    size_t cur_var_id = var_id++;
-    vars[n] = cur_var_id;
-
-    sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
-         << std::endl;
-  }
-
-  size_t op_id = 0;
-  for (const ir::Node* n : graph->Nodes()) {
-    if (n->NodeType() != ir::Node::Type::kOperation) continue;
-    std::string op_name = "op_" + std::to_string(op_id++);
-    sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
-         << std::endl;
-    for (auto in : n->inputs) {
-      std::string var_name = "var_" + std::to_string(vars[in]);
-      sout << var_name << " -> " << op_name << std::endl;
-    }
-    for (auto out : n->outputs) {
-      std::string var_name = "var_" + std::to_string(vars[out]);
-      sout << op_name << " -> " << var_name << std::endl;
-    }
-  }
-
-  sout << "}\n";
+  std::unordered_map<const ir::Node*, std::string> node2dot;
+
+  Dot dot;
+
+  std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
+                                   Dot::Attr("shape", "box"),
+                                   Dot::Attr("fillcolor", "red")});
+  std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
+                                    // Dot::Attr("shape", "diamond"),
+                                    Dot::Attr("fillcolor", "yellow")});
+
+  std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
+                                          Dot::Attr("shape", "box"),
+                                          Dot::Attr("fillcolor", "lightgray")});
+  std::vector<Dot::Attr> marked_var_attrs(
+      {Dot::Attr("style", "filled,rounded"),
+       // Dot::Attr("shape", "diamond"),
+       Dot::Attr("fillcolor", "lightgray")});
+
+  auto marked_nodes = ConsumeMarkedNodes(graph.get());
+  // Create nodes
+  for (const Node* n : graph->Nodes()) {
+    std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
+    if (n->IsOp()) {
+      decltype(op_attrs) attr =
+          marked_nodes.count(n) ? marked_op_attrs : op_attrs;
+      dot.AddNode(node_id, attr, node_id);
+    } else if (n->IsVar()) {
+      decltype(op_attrs) attr =
+          marked_nodes.count(n) ? marked_var_attrs : var_attrs;
+      dot.AddNode(node_id, attr, node_id);
+    }
+    node2dot[n] = node_id;
+  }
+  // Create edges
+  for (const Node* n : graph->Nodes()) {
+    const auto& src_id = node2dot.at(n);
+    for (auto* out : n->outputs) {
+      const auto& trg_id = node2dot.at(out);
+      dot.AddEdge(src_id, trg_id, {});
+    }
+  }
+
+  sout << dot.Build();

   return graph;
 }

+GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
+    Graph* graph) const {
+  marked_nodes_t res;
+  if (graph->Has(kGraphvizMarkedNodeAttr)) {
+    auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
+    res = attr;
+    attr.clear();
+  }
+  return res;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/graph_viz_pass.h

@@ -27,10 +27,19 @@ namespace paddle {
 namespace framework {
 namespace ir {

+const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
+
 class GraphVizPass : public Pass {
+ public:
+  using marked_nodes_t = std::unordered_set<const Node*>;
+
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override;

+  // Tell whether there are any marked nodes in the graph. Consume the
+  // corresponding attribute.
+  marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
 };

 }  // namespace ir
paddle/fluid/framework/ir/node.h

@@ -29,20 +29,26 @@ class Node {
   enum class Type { kOperation, kVariable };
   static constexpr char kControlDepVarName[] = "__control_var";

-  explicit Node(const std::string& name, Type type)
-      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
+  explicit Node(const std::string& name, Type type, int id = -1)
+      : name_(name),
+        var_desc_(nullptr),
+        op_desc_(nullptr),
+        type_(type),
+        id_(id) {}

-  explicit Node(VarDesc* var_desc)
+  explicit Node(VarDesc* var_desc, int id = -1)
       : name_(var_desc->Name()),
         var_desc_(new VarDesc(*var_desc)),
         op_desc_(nullptr),
-        type_(Type::kVariable) {}
+        type_(Type::kVariable),
+        id_(id) {}

-  explicit Node(OpDesc* op_desc)
+  explicit Node(OpDesc* op_desc, int id = -1)
       : name_(op_desc->Type()),
         var_desc_(nullptr),
         op_desc_(new OpDesc(*op_desc, op_desc->Block())),
-        type_(Type::kOperation) {}
+        type_(Type::kOperation),
+        id_(id) {}

   Type NodeType() const { return type_; }
@@ -58,6 +64,8 @@ class Node {
     return op_desc_.get();
   }

+  int id() const { return id_; }
+
   bool IsOp() const { return type_ == Type::kOperation; }
   bool IsVar() const { return type_ == Type::kVariable; }
@@ -69,6 +77,7 @@ class Node {
   std::unique_ptr<VarDesc> var_desc_;
   std::unique_ptr<OpDesc> op_desc_;
   Type type_;
+  int id_;

  private:
   DISABLE_COPY_AND_ASSIGN(Node);
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace paddle {
namespace framework {
namespace ir {

struct FuseExpr {};

// sequence expand, concat fuse pattern, return concat's output
PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
  // The following operators will be fused:
  //   concat
  //   sequence_expand
  //   sequence_expand
  // The following variables will be treat as inputs:
  //   concat mid input, 0th input for fused op
  //   sequence_expand input, 1th input for fused op
  //   sequence_expand input, 2th input for fused op
  // The following variables will be treat as outputs:
  //   concat output
  // So the following variables will be removed:
  //   sequence-expand output
  //   sequence-expand output

  // Three operators
  auto* sequence_expand0 = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
      },
      "sequence_expand0");

  auto* sequence_expand1 = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
      },
      "sequence_expand1");

  auto* concat = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "concat" &&  // basic check
               x->Op()->Input("X").size() == 3;                  // Special case
      },
      "concat");

  auto* sequence_expand0_in = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
      },
      "sequence_expand0_in");
  auto* sequence_expand1_in = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
      },
      "sequence_expand1_in");

  // The variables
  auto* sequence_expand0_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&
               VarLinksFromOp(x, "sequence_expand") &&  // basic check
               VarLinksToOp(x, "concat") &&             // is concat's input
               IsNthInput(x, x->outputs[0], "X", 1);    // X[0]
      },
      "sequence_expand0_out");

  auto* sequence_expand1_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&
               VarLinksFromOp(x, "sequence_expand") &&  // basic check
               VarLinksToOp(x, "concat") &&             // is concat's input
               IsNthInput(x, x->outputs[0], "X", 2);    // x[2]
      },
      "sequence_expand1_out");

  auto* concat_in0 = pattern->NewNode(
      [](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
      "concat_in0");

  auto* concat_out = pattern->NewNode(
      [](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
      "concat_out");

  // Links
  sequence_expand0->LinksFrom({sequence_expand0_in})
      .LinksTo({sequence_expand0_out});
  sequence_expand1->LinksFrom({sequence_expand1_in})
      .LinksTo({sequence_expand1_out});
  concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
      .LinksTo({concat_out});
  return concat_out;
}

PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
  PDNode* fc_w = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                 // basic
               VarLinksToOp(x, "mul") &&          // link
               x->Var()->Proto()->persistable();  // is a parameter
      },
      "fc_w");

  PDNode* mul_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                     // basic
               VarLinksFromOp(x, "mul") &&            // link
               VarLinksToOp(x, "elementwise_add") &&  //
               !x->Var()->Proto()->persistable();     // is a parameter
      },
      "mul_out");

  PDNode* fc_mul = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "mul";  // basic
      },
      "fc_mul");

  PDNode* fc_bias = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                     // basic
               VarLinksToOp(x, "elementwise_add") &&  // link
               x->Var()->Proto()->persistable();      // is a parameter
      },
      "fc_bias");

  PDNode* elementwise_add = pattern->NewNode(
      [](Node* x) {
        return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
      },
      "elementwise_add");

  PDNode* add_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                       // basic
               VarLinksFromOp(x, "elementwise_add") &&  // link
               !x->Var()->Proto()->persistable();       // is a parameter
      },
      "add_out");

  std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
  PDNode* act = pattern->NewNode(
      [=](Node* x) {
        return x && x->IsOp() && acts.count(x->Op()->Type());
      },
      "act");

  PDNode* fc_out = pattern->NewNode(
      [](Node* x) {
        return x && x->IsVar() &&                  // basic
               !x->Var()->Proto()->persistable();  // is a parameter
      },
      "fc_out");

  fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
  elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
  act->LinksFrom({add_out}).LinksTo({fc_out});
  return fc_out;
}

std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(graph.get());
  GraphPatternDetector detector;
  auto* pattern = detector.mutable_pattern();
  auto* concat_out = BuildSeqExpandConcatPattern(pattern);
  BuildFCPattern(pattern, concat_out);

#define GET_NODE(id, pattern)                              \
  PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
                 "pattern has no Node called %s", #id);    \
  auto* id = subgraph.at(pattern.RetriveNode(#id));        \
  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

  detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
                            Graph* graph) {
    VLOG(4) << "get one concat pattern";
    // fc
    GET_NODE(fc_w, detector.pattern());
    GET_NODE(fc_bias, detector.pattern());
    GET_NODE(act, detector.pattern());
    GET_NODE(fc_out, detector.pattern());

    // concat
    GET_NODE(concat_in0, detector.pattern());
    GET_NODE(sequence_expand0_in, detector.pattern());
    GET_NODE(sequence_expand1_in, detector.pattern());

    OpDesc op_desc;
    op_desc.SetType("fusion_seqexpand_concat_fc");
    op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
                           sequence_expand1_in->Name()});
    op_desc.SetInput("FCWeight", {fc_w->Name()});
    op_desc.SetInput("FCBias", {fc_bias->Name()});
    const std::string fc_out_tmp = fc_out->Name() + ".tmp";
    param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
    op_desc.SetOutput("FCOut", {fc_out_tmp});
    op_desc.SetOutput("Out", {fc_out->Name()});
    op_desc.SetAttr("fc_activation", act->Op()->Type());

    auto* op_node = graph->CreateOpNode(&op_desc);

// Add links
#define NODE_LINKS(a, b)   \
  a->outputs.push_back(b); \
  b->inputs.push_back(a);
    NODE_LINKS(fc_w, op_node);
    NODE_LINKS(fc_bias, op_node);
    NODE_LINKS(concat_in0, op_node);
    NODE_LINKS(sequence_expand0_in, op_node);
    NODE_LINKS(sequence_expand1_in, op_node);
    NODE_LINKS(op_node, fc_out);

    // Clean nodes.
    std::unordered_set<const Node*> marked_nodes;
    for (auto& item : subgraph) {
      marked_nodes.insert(item.second);
    }
    marked_nodes.erase(fc_w);
    marked_nodes.erase(fc_bias);
    marked_nodes.erase(concat_in0);
    marked_nodes.erase(sequence_expand0_in);
    marked_nodes.erase(sequence_expand1_in);
    marked_nodes.erase(fc_out);

    GraphSafeRemoveNodes(graph, marked_nodes);
  });

  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(seq_concat_fc_fuse_pass,
              paddle::framework::ir::SeqConcatFcFusePass);
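
For context, a pass registered with REGISTER_PASS as above is normally looked up by that name and applied to a graph. The sketch below assumes the PassRegistry / Pass::Apply interface that ir_pass_manager.cc in this same commit builds on; treat it as illustrative rather than the commit's exact driver code:

#include "paddle/fluid/framework/ir/pass.h"

// Fetch a registered fuse pass by its REGISTER_PASS name and run it.
std::unique_ptr<paddle::framework::ir::Graph> RunSeqConcatFcFuse(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "seq_concat_fc_fuse_pass");
  // Apply() dispatches to the pass's ApplyImpl() and returns the
  // rewritten graph.
  return pass->Apply(std::move(graph));
}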
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class SeqConcatFcFusePass : public FusePassBase {
 public:
  virtual ~SeqConcatFcFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
902f19b4
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
-cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
+set(analysis_deps
+    framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
+cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
    analyzer.cc
    helper.cc
    # passes
...
...
@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
    tensorrt_subgraph_node_mark_pass.cc
    fluid_to_ir_pass.cc
    model_store_pass.cc
-   DEPS framework_proto proto_desc ir_pass_manager graph pass)
+   DEPS ${analysis_deps})

cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
-cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
+cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)

set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
...
...
@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
  endif()
  cc_test(${TARGET}
          SRCS "${analysis_test_SRCS}"
-         DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS}
+         DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
          ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
  set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
...
...
@@ -58,20 +61,25 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
    EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
+   analysis_predictor
    # ir
    fc_fuse_pass
+   fc_lstm_fuse_pass
+   seq_concat_fc_fuse_pass
    graph_viz_pass
    infer_clean_graph_pass
-   graph_pattern_detecter
+   graph_pattern_detector
+   attention_lstm_fuse_pass
    paddle_inference_api
    pass
    ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
+        --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
+        --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)

inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
-inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
-inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
+inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc
+                        EXTRA_DEPS paddle_inference_api)
+inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc
+                        EXTRA_DEPS paddle_fluid)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
...
...
paddle/fluid/inference/analysis/analyzer.cc
...
...
@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }

void Analyzer::Run(Argument* argument) {
+  // Ugly support for fluid-to-ir-pass; manually update the pass list here.
+  argument->Set(kFluidToIrPassesAttr,
+                new std::vector<std::string>({
+                    "graph_viz_pass",            //
+                    "infer_clean_graph_pass",    //
+                    "graph_viz_pass",            //
+                    "attention_lstm_fuse_pass",  //
+                    "graph_viz_pass",            //
+                    "fc_lstm_fuse_pass",         //
+                    "graph_viz_pass",            //
+                    "seq_concat_fc_fuse_pass",   //
+                    "graph_viz_pass",            //
+                    "fc_fuse_pass",              //
+                    "graph_viz_pass"             //
+                }));
  for (auto& x : data_) {
    PADDLE_ENFORCE(x->Initialize(argument));
    x->RunAll();
...
...
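The interleaved graph_viz_pass entries dump a DOT snapshot after each fuse stage, which makes it easy to diff the IR before and after every fusion. A minimal, hypothetical driver for this pipeline; Argument's model-path constructor and Analyzer::Run are used the same way in the tests below:

// Sketch: run the registered analysis passes (including fluid-to-ir)
// over a saved inference model. The path is illustrative.
paddle::inference::analysis::Argument argument("/path/to/inference_model");
paddle::inference::analysis::Analyzer analyzer;
analyzer.Run(&argument);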
paddle/fluid/inference/analysis/analyzer_tester.cc
...
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
...
...
@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
                           const std::string &data_path, int batch_size,
                           bool use_analysis, bool activate_ir,
                           int num_times = 1) {
-  FLAGS_IA_enable_ir = activate_ir;
-  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
-  FLAGS_IA_output_storage_path = "./analysis.out";
-
-  std::string model_out;
-  if (use_analysis) {
-    Argument argument(model_path);
-    argument.model_output_store_path.reset(new std::string("./analysis.out"));
-    Analyzer analyzer;
-    analyzer.Run(&argument);
-    // Should get the transformed model stored to ./analysis.out
-    model_out = "./analysis.out";
-    ASSERT_TRUE(PathExists(model_out));
-  } else {
-    model_out = FLAGS_infer_ditu_rnn_model;
-  }
-
  NativeConfig config;
-  config.prog_file = model_out + "/__model__";
-  config.param_file = model_out + "/param";
+  config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
+  config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
  config.use_gpu = false;
  config.device = 0;
  config.specify_input_name = true;

-  auto predictor =
+  auto base_predictor =
      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
+  auto predictor =
+      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
  std::vector<PaddleTensor> input_slots;
  DataRecord data(data_path, batch_size);
  // Prepare inputs.
  PrepareInputs(&input_slots, &data, batch_size);
-  std::vector<PaddleTensor> outputs;
+  std::vector<PaddleTensor> outputs, base_outputs;
+
+  base_predictor->Run(input_slots, &base_outputs);

  Timer timer;
  timer.tic();
...
...
@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
            << ", latency: " << timer.toc() / num_times << "ms";
  LOG(INFO) << "=====================================";

-  for (auto &out : outputs) {
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
+  for (size_t i = 0; i < outputs.size(); i++) {
+    auto &out = outputs[i];
+    auto &base_out = base_outputs[i];
    size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
                                  [](int a, int b) { return a * b; });
+    size_t size1 = std::accumulate(base_out.shape.begin(),
+                                   base_out.shape.end(), 1,
+                                   [](int a, int b) { return a * b; });
+    PADDLE_ENFORCE_EQ(size, size1);
+    PADDLE_ENFORCE_GT(size, 0);
    float *data = static_cast<float *>(out.data.data());
-    for (size_t i = 0;
-         i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size);
-         i++) {
-      EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
+    float *base_data = static_cast<float *>(base_out.data.data());
+    for (size_t i = 0; i < size; i++) {
+      EXPECT_NEAR(data[i], base_data[i], 1e-3);
    }
  }
}

-// Turn on IR pass support, run a real inference and check the result.
-TEST(Analyzer, SupportIRPass) {
-  FLAGS_IA_enable_ir = true;
-  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
-  FLAGS_IA_output_storage_path = "./analysis.out";
-
-  Argument argument(FLAGS_inference_model_dir);
-  argument.model_output_store_path.reset(new std::string("./analysis.out"));
-  Analyzer analyzer;
-  analyzer.Run(&argument);
-
-  // Should get the transformed model stored to ./analysis.out
-  ASSERT_TRUE(PathExists("./analysis.out"));
-
-  // Inference from this path.
-  TestWord2vecPrediction("./analysis.out");
-}
-
// Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) {
  TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
...
...
@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
}  // namespace paddle

USE_PASS(fc_fuse_pass);
+USE_PASS(seq_concat_fc_fuse_pass);
+USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
+USE_PASS(attention_lstm_fuse_pass);
paddle/fluid/inference/analysis/argument.h
...
...
@@ -26,6 +26,7 @@
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace inference {
...
...
@@ -58,6 +59,46 @@ struct Argument {
  // The output storage path of ModelStorePass.
  std::unique_ptr<std::string> model_output_store_path;

+  // Support for any other attributes.
+  template <typename T>
+  void Set(const std::string& key, T* data) {
+    PADDLE_ENFORCE_NOT_NULL(data);
+    PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
+    attrs_[key] = data;
+    attr_deleters_[key] = [data, key, this]() {
+      VLOG(3) << "argument delete attr: " << key;
+      delete data;
+    };
+  }
}
+  bool Has(const std::string& name) const { return attrs_.count(name); }
+
+  template <typename T>
+  T* Release(const std::string& key) {
+    PADDLE_ENFORCE(attrs_.count(key));
+    auto* res = boost::any_cast<T*>(attrs_.at(key));
+    attrs_.erase(key);
+    attr_deleters_.erase(key);
+    return res;
+  }
+
+  template <typename T>
+  T& Get(const std::string& key) {
+    PADDLE_ENFORCE(Has(key));
+    return *boost::any_cast<T*>(attrs_.at(key));
+  }
+
+  ~Argument() {
+    for (auto& item : attr_deleters_) {
+      item.second();
+    }
+  }
+
+ private:
+  std::unordered_map<std::string, boost::any> attrs_;
+  std::unordered_map<std::string, std::function<void()>> attr_deleters_;
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
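A short, hypothetical usage sketch of these attribute helpers; ownership of the raw pointer passes to the Argument on Set() and back to the caller on Release():

#include <memory>
#include <string>
#include <vector>

void AttrDemo(paddle::inference::analysis::Argument* argument) {
  // The Argument now owns the vector and deletes it in its destructor.
  argument->Set("passes", new std::vector<std::string>({"fc_fuse_pass"}));
  if (argument->Has("passes")) {
    auto& passes = argument->Get<std::vector<std::string>>("passes");
    (void)passes;  // read-only use
  }
  // Release() hands ownership back; the deleter entry is dropped.
  std::unique_ptr<std::vector<std::string>> owned(
      argument->Release<std::vector<std::string>>("passes"));
}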
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle {
namespace inference {
...
...
@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
    }
  }

+  if (argument_->Has("param_scope")) {
+    LOG(WARNING) << "parameter changes in the scope take effect";
+  }
+
  PADDLE_ENFORCE(argument_->transformed_program_desc.get());
}
...
...
paddle/fluid/inference/analysis/dot.h
...
...
@@ -29,13 +29,13 @@ namespace paddle {
namespace inference {
namespace analysis {

+static size_t dot_node_counter{0};
+
/*
 * A Dot template that helps to build a DOT graph definition.
 */
class Dot {
 public:
-  static size_t counter;
  struct Attr {
    std::string key;
    std::string value;
...
...
@@ -57,7 +57,7 @@ class Dot {
    Node(const std::string& name, const std::vector<Attr>& attrs)
        : name(name),
          attrs(attrs),
-          id_("node_" + std::to_string(Dot::counter++)) {}
+          id_("node_" + std::to_string(dot_node_counter++)) {}

    std::string id() const { return id_; }
...
...
@@ -65,6 +65,10 @@ class Dot {
      std::stringstream ss;
      CHECK(!name.empty());
      ss << id_;
+      if (attrs.empty()) {
+        ss << "[label=" << '"' << name << '"' << "]";
+        return ss.str();
+      }
      for (size_t i = 0; i < attrs.size(); i++) {
        if (i == 0) {
          ss << "[label=" << '"' << name << '"' << " ";
...
...
@@ -108,9 +112,11 @@ class Dot {
  explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}

-  void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
-    CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
-    nodes_.emplace(name, Node{name, attrs});
+  void AddNode(const std::string& id, const std::vector<Attr>& attrs,
+               std::string label = "") {
+    CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
+    if (label.empty()) label = id;
+    nodes_.emplace(id, Node{label, attrs});
  }

  void AddEdge(const std::string& source, const std::string& target,
...
...
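With this change a node's DOT id is decoupled from its display label, so two nodes may share the label "relu" while keeping unique ids. A small, hypothetical sketch; the trailing attrs parameter of AddEdge and the Build() serializer are assumed from the surrounding class:

paddle::inference::analysis::Dot dot;
dot.AddNode("node_relu_1", {}, "relu");  // id differs from label
dot.AddNode("node_relu_2", {}, "relu");  // same label, no id clash
dot.AddEdge("node_relu_1", "node_relu_2", {});
std::string dot_text = dot.Build();  // render with: dot -Tpng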
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
...
...
@@ -13,3 +13,47 @@
// limitations under the License.

#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/inference/io.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void FluidToIrPass::EnableParamModify(const std::string &model_dir,
+                                      const std::string &prog_file,
+                                      const std::string &param_file) {
+  PADDLE_ENFORCE(argument_);
+  argument_->Set("param_scope", new framework::Scope);
+  // Load parameters.
+  VLOG(3) << "Loading parameters from " << model_dir;
+  LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
+             prog_file, param_file);
+}
+
+bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
+                               const std::string &prog_file,
+                               const std::string &param_file) {
+  platform::CPUPlace place;
+  platform::CPUDeviceContext ctx(place);
+  framework::Executor executor(place);
+  PADDLE_ENFORCE(argument_->origin_program_desc.get());
+  framework::ProgramDesc program(*argument_->origin_program_desc);
+  if ((!prog_file.empty()) && (!param_file.empty())) {
+    LOG(INFO) << "load single model file from " << prog_file;
+    Load(&executor, scope, prog_file, param_file);
+  } else if (!dir.empty()) {
+    LOG(INFO) << "load from dir " << dir;
+    Load(&executor, scope, dir);
+  } else {
+    LOG(ERROR) << "failed to load parameters";
+    return false;
+  }
+  return true;
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
...
...
@@ -21,12 +21,17 @@ namespace paddle {
namespace inference {
namespace analysis {

+static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
+
class FluidToIrPass final : public DataFlowGraphPass {
 public:
  FluidToIrPass() = default;

  bool Initialize(Argument *argument) override {
    ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
+    PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
+                   "argument need the attr %s", kFluidToIrPassesAttr);
+    argument_ = argument;
    if (argument->origin_program_desc) {
      LOG(WARNING) << "argument's origin_program_desc is already set, might "
                      "be a duplicate call";
...
...
@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
    if (!argument->main_dfg) {
      argument->main_dfg.reset(new DataFlowGraph);
    }
-    // Persist the ProgramDesc in the graph's attribute. The IR graph only
-    // keeps the address and will segfault if the original ProgramDesc is
-    // destroyed.
-    auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer();
-    ir_program_p = new framework::ProgramDesc(program);
+    argument->Set("ir_program_desc", new framework::ProgramDesc(program));

    LOG(INFO) << "Loading parameters";
+    // Load parameters into the argument if needed.
+    if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
+                                      argument->fluid_model_param_path)) {
+#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
+      SAFE_GET(fluid_model_dir);
+      SAFE_GET(fluid_model_program_path);
+      SAFE_GET(fluid_model_param_path);
+#undef SAFE_GET
+      EnableParamModify(fluid_model_dir, fluid_model_program_path,
+                        fluid_model_param_path);
+    }
    argument_ = argument;
    return true;
  }
...
...
@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
  void Run(DataFlowGraph *graph) override {
    // Call all the IR passes.
-    IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>(
-        argument_->main_dfg->Attr("ir_program_desc").Pointer()));
-    ir_passes.Apply(std::vector<std::string>(
-        {// Manually update the passes here.
-         "graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass",
-         "fc_fuse_pass", "graph_viz_pass"}));
+    IRPassManager ir_passes(
+        argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
+    // Pass the scope from analysis to IR if needed.
+    if (argument_->Has("param_scope")) {
+      // Only the address is passed here; IR does not own the scope, so the
+      // real scope owned by analysis must outlive the IR phase.
+      ir_passes.graph().Set(
+          "param_scope", new framework::Scope *(
+                             &argument_->Get<framework::Scope>("param_scope")));
+    }
+    const auto &ir_passes_to_apply =
+        argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
+    ir_passes.Apply(ir_passes_to_apply);

    PADDLE_ENFORCE(argument_->main_dfg.get());
    argument_->main_dfg->Build(ir_passes.graph());
    // PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
  }

+  void EnableParamModify(const std::string &model_dir,
+                         const std::string &prog_file,
+                         const std::string &param_file);
+
  std::string repr() const override { return "fluid-to-ir-pass"; }

 private:
+  // Load parameters from a single file or from a directory.
+  bool LoadParams(framework::Scope *scope, const std::string &dir,
+                  const std::string &prog_file, const std::string &param_file);
+
  Argument *argument_{nullptr};
};
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
...
...
@@ -24,6 +24,8 @@ namespace analysis {
TEST(FluidToIrPass, Test) {
  FluidToIrPass pass;
  Argument argument(FLAGS_inference_model_dir);
+  argument.Set(kFluidToIrPassesAttr,
+               new std::vector<std::string>({"infer_clean_graph_pass"}));
  pass.Initialize(&argument);
  pass.Run(argument.main_dfg.get());
}
...
...
@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
}  // namespace inference
}  // namespace paddle

-USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
+USE_PASS(attention_lstm_fuse_pass);
+USE_PASS(fc_lstm_fuse_pass);
+USE_PASS(seq_concat_fc_fuse_pass);
+USE_PASS(fc_fuse_pass);
paddle/fluid/inference/analysis/ir_pass_manager.cc
...
...
@@ -14,20 +14,24 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace inference {
namespace analysis {

-IRPassManager::IRPassManager(const ProgramDesc &program) {
+IRPassManager::IRPassManager(const ProgramDesc &program,
+                             framework::Scope *scope)
+    : program_(program) {
  graph_.reset(new framework::ir::Graph(program));
+  if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
}

void IRPassManager::Apply(const std::vector<std::string> &passes) {
-  graph_->Set("graph_viz_path", new std::string("./1.dot"));
  // Apply all the passes
  std::string pre_pass;
  for (const std::string &pass_name : passes) {
+    LOG(WARNING) << "Running IR pass [" << pass_name << "]";
    auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
    if (pass_name == "graph_viz_pass") {
...
...
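A minimal, hypothetical sketch of driving IRPassManager directly, given a loaded framework::ProgramDesc `program` and an optional parameter scope; both the two-argument constructor and Apply() are introduced by this change:

paddle::framework::Scope scope;  // holds the model parameters, shared with IR
paddle::inference::analysis::IRPassManager manager(program, &scope);
manager.Apply({"infer_clean_graph_pass", "fc_fuse_pass"});
paddle::framework::ir::Graph &optimized_graph = manager.graph();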
paddle/fluid/inference/analysis/ir_pass_manager.h
...
...
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace inference {
...
...
@@ -31,14 +32,15 @@ using framework::ProgramDesc;
class IRPassManager final {
 public:
-  IRPassManager(const ProgramDesc &program);
+  IRPassManager(const ProgramDesc &program, framework::Scope *scope);

  void Apply(const std::vector<std::string> &passes);

  framework::ir::Graph &graph() const { return *graph_; }

 private:
  std::unique_ptr<framework::ir::Graph> graph_;
+  ProgramDesc program_;
};

}  // namespace analysis
...
...
paddle/fluid/inference/analysis/pass_manager.cc
...
...
@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() {
  PADDLE_ENFORCE(argument_);
-  LOG(INFO) << "Total " << data_.size() << " passes";
+  LOG(INFO) << "Total " << data_.size() << " Analysis passes";
  for (auto &pass : data_) {
-    LOG(WARNING) << "Running pass [" << pass->repr() << "]";
+    LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]";
    pass->Run(argument_->main_dfg.get());
  }
}
...
...
paddle/fluid/inference/api/CMakeLists.txt
...
...
@@ -20,7 +20,7 @@ endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager
    graph_viz_pass fc_fuse_pass infer_clean_graph_pass)

if(WITH_GPU AND TENSORRT_FOUND)
...
...
@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
endif(WITH_TESTING)
endfunction(inference_api_test)

-cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor)
+cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
+cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)

cc_test(test_paddle_inference_api
        SRCS api_tester.cc
...
paddle/fluid/inference/api/analysis_predictor.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {

using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;

/* This predictor is based on the original native predictor with IR and
 * Analysis support. It will optimize the IR and the parameters at runtime.
 * TODO(Superjomn) Replace the Native predictor?
 */
class AnalysisPredictor : public NativePaddlePredictor {
 public:
  explicit AnalysisPredictor(const NativeConfig& config)
      : NativePaddlePredictor(config), config_(config) {}

  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
    VLOG(3) << "Predictor::init()";
    if (config_.use_gpu) {
      place_ = paddle::platform::CUDAPlace(config_.device);
    } else {
      place_ = paddle::platform::CPUPlace();
    }
    PADDLE_ENFORCE(!parent_scope);
    if (parent_scope) {
      scope_ = parent_scope;
      sub_scope_ = &(parent_scope->NewScope());
    } else {
      paddle::framework::InitDevices(false);
      scope_.reset(new paddle::framework::Scope());
    }

    executor_.reset(new paddle::framework::Executor(place_));

    // Initialize the inference program
    if (!config_.model_dir.empty()) {
      // Parameters are saved in separate files sited in
      // the specified `dirname`.
      inference_program_ = paddle::inference::Load(
          executor_.get(), scope_.get(), config_.model_dir);
    } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
      // All parameters are saved in a single file.
      // The file names should be consistent with that used
      // in Python API `fluid.io.save_inference_model`.
      inference_program_ = paddle::inference::Load(
          executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
    } else {
      LOG(ERROR) << "fail to load inference model.";
      return false;
    }

    OptimizeInferenceProgram();
    ctx_ = executor_->Prepare(*inference_program_, 0);

    VLOG(5) << "to create variables";
    PADDLE_ENFORCE(scope_.get());
    executor_->CreateVariables(*inference_program_,
                               sub_scope_ ? sub_scope_ : scope_.get(), 0);
    // Get the feed_target_names and fetch_target_names
    feed_target_names_ = inference_program_->GetFeedTargetNames();
    fetch_target_names_ = inference_program_->GetFetchTargetNames();
    return true;
  }

  bool Run(const std::vector<PaddleTensor>& inputs,
           std::vector<PaddleTensor>* output_data,
           int batch_size = -1) override {
    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
  }

  void OptimizeInferenceProgram() {
    LOG(INFO) << "optimize begin";
    FLAGS_IA_enable_ir = true;
    FLAGS_IA_enable_tensorrt_subgraph_engine = false;
    FLAGS_IA_output_storage_path = "";  // Don't output the model.
    // Analyze inference_program
    Argument argument;
    if (!config_.model_dir.empty()) {
      argument.fluid_model_dir.reset(new std::string(config_.model_dir));
    } else {
      PADDLE_ENFORCE(
          !config_.param_file.empty(),
          "Either model_dir or (param_file, prog_file) should be set.");
      PADDLE_ENFORCE(!config_.prog_file.empty());
      argument.fluid_model_program_path.reset(
          new std::string(config_.prog_file));
      argument.fluid_model_param_path.reset(
          new std::string(config_.param_file));
    }
    argument.origin_program_desc.reset(
        new ProgramDesc(*inference_program_->Proto()));
    Singleton<Analyzer>::Global().Run(&argument);
    CHECK(argument.transformed_program_desc);
    VLOG(5) << "to prepare executor";
    // LOG(INFO) << "transformed_program_desc " <<
    //     argument.transformed_program_desc->DebugString();
    inference_program_.reset(
        new framework::ProgramDesc(*argument.transformed_program_desc));
    PADDLE_ENFORCE(argument.Has("param_scope"));
    // Update scope.
    scope_.reset(argument.Release<framework::Scope>("param_scope"));
    LOG(INFO) << "optimize end ==";
  }

 private:
  NativeConfig config_;
};

template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
    const NativeConfig& config) {
  VLOG(3) << "create NativePredictor";
  if (config.use_gpu) {
    // 1. GPU memory
    PADDLE_ENFORCE_GT(
        config.fraction_of_gpu_memory, 0.f,
        "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
    PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
    std::vector<std::string> flags;
    if (config.fraction_of_gpu_memory >= 0.0f ||
        config.fraction_of_gpu_memory <= 0.95f) {
      flags.push_back("dummpy");
      std::string flag = "--fraction_of_gpu_memory_to_use=" +
                         std::to_string(config.fraction_of_gpu_memory);
      flags.push_back(flag);
      VLOG(3) << "set flag: " << flag;
      framework::InitGflags(flags);
    }
  }

  std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
  if (!dynamic_cast<AnalysisPredictor*>(predictor.get())->Init(nullptr)) {
    return nullptr;
  }
  return predictor;
}

}  // namespace paddle

USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
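A minimal, hypothetical sketch of creating the new analysis-enabled predictor through the public API; the model path is illustrative, and the config fields are the same NativeConfig fields used above:

paddle::NativeConfig config;
config.model_dir = "./my_model";  // hypothetical dir saved by fluid.io.save_inference_model
config.use_gpu = false;
config.device = 0;
auto predictor = paddle::CreatePaddlePredictor<
    paddle::NativeConfig, paddle::PaddleEngineKind::kAnalysis>(config);
std::vector<paddle::PaddleTensor> inputs, outputs;
// ... fill `inputs` with shaped PaddleTensor data ...
predictor->Run(inputs, &outputs);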
paddle/fluid/inference/api/helper.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace inference {

template <>
std::string to_string<std::vector<float>>(
    const std::vector<std::vector<float>> &vec) {
  std::stringstream ss;
  for (const auto &piece : vec) {
    ss << to_string(piece) << "\n";
  }
  return ss.str();
}

template <>
std::string to_string<std::vector<std::vector<float>>>(
    const std::vector<std::vector<std::vector<float>>> &vec) {
  std::stringstream ss;
  for (const auto &line : vec) {
    for (const auto &rcd : line) {
      ss << to_string(rcd) << ";\t";
    }
    ss << '\n';
  }
  return ss.str();
}

}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/api/helper.h
...
...
@@ -44,7 +44,8 @@ class Timer {
  }
};

-void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
+static void split(const std::string &str, char sep,
+                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
...
...
@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
    pieces->push_back(str.substr(pos));
  }
}
-void split_to_float(const std::string &str, char sep, std::vector<float> *fs) {
+static void split_to_float(const std::string &str, char sep,
+                           std::vector<float> *fs) {
  std::vector<std::string> pieces;
  split(str, sep, &pieces);
  std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
...
...
@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
}

template <>
std::string to_string<std::vector<float>>(
-    const std::vector<std::vector<float>> &vec) {
-  std::stringstream ss;
-  for (const auto &piece : vec) {
-    ss << to_string(piece) << "\n";
-  }
-  return ss.str();
-}
+    const std::vector<std::vector<float>> &vec);

template <>
std::string to_string<std::vector<std::vector<float>>>(
-    const std::vector<std::vector<std::vector<float>>> &vec) {
-  std::stringstream ss;
-  for (const auto &line : vec) {
-    for (const auto &rcd : line) {
-      ss << to_string(rcd) << ";\t";
-    }
-    ss << '\n';
-  }
-  return ss.str();
-}
+    const std::vector<std::vector<std::vector<float>>> &vec);

// clang-format off
-void TensorAssignData(PaddleTensor *tensor,
-                      const std::vector<std::vector<float>> &data) {
+static void TensorAssignData(PaddleTensor *tensor,
+                             const std::vector<std::vector<float>> &data) {
  // Assign buffer
  int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1,
                            [](int a, int b) { return a * b; });
  tensor->data.Resize(sizeof(float) * dim);
...
...
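These helpers become internal-linkage (static) so the header can be included from multiple translation units without duplicate-symbol errors, and the out-of-line to_string specializations move into helper.cc for the same reason. A tiny, hypothetical usage sketch:

std::vector<float> values;
paddle::inference::split_to_float("0.1,0.2,0.3", ',', &values);
// values == {0.1f, 0.2f, 0.3f}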
paddle/fluid/inference/api/paddle_inference_api.h
...
...
@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
  kNative = 0,         // Use the native Fluid facility.
  kAnakin,             // Use Anakin for inference.
  kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
+  kAnalysis
  // TODO(Superjomn) support the following engines later.
  // kTensorRT,         // Use TensorRT for inference.
  // kAutoMixedAnakin,  // Automatically mix Fluid with Anakin.
...
...
paddle/fluid/inference/io.cc
...
...
@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
  return main_program;
}

+void SaveVars(const framework::Scope& scope,
+              const std::vector<std::string>& vars, const std::string& dirname,
+              bool predicate) {
+  framework::ProgramDesc prog;
+  auto* block = prog.MutableBlock(0);
+  auto* op = block->AppendOp();
+  op->SetType("save_combine");
+  op->SetInput("X", vars);
+  op->SetAttr("file_path", dirname + "/param");
+  op->CheckAttrs();
+
+  platform::CPUPlace place;
+  framework::Executor exe(place);
+  exe.Run(prog, const_cast<framework::Scope*>(&scope), 0, true, true);
+}
+
}  // namespace inference
}  // namespace paddle
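SaveVars works by synthesizing a one-op program containing a single save_combine op and executing it on CPU, so all named variables end up combined in <dirname>/param. A short, hypothetical usage sketch; the variable names are illustrative:

paddle::framework::Scope scope;
// ... scope is populated with trained parameters "w0" and "b0" ...
paddle::inference::SaveVars(scope, {"w0", "b0"}, "./saved_model");
// produces ./saved_model/param in save_combine format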
paddle/fluid/inference/io.h
...
...
@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                             const std::string& prog_filename,
                                             const std::string& param_filename);

+// Save the variables from a scope to disk.
+void SaveVars(const framework::Scope& scope,
+              const std::vector<std::string>& vars, const std::string& dirname,
+              bool predicate = true);
+
}  // namespace inference
}  // namespace paddle
paddle/fluid/operators/attention_lstm_op.cc
...
...
@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
...
...
paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc
...
...
@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
"FC height should be sum of all inputs width."
);
if
(
ctx
->
HasInput
(
"FCBias"
))
{
auto
b_dims
=
ctx
->
GetInputDim
(
"FCBias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"Input(FCBias)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"FCBias shapes must be 1 * %d."
,
D
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
D
,
"FCBias shapes must be 1 * %d."
,
D
);
PADDLE_ENFORCE
(
b_dims
.
size
()
==
1
||
b_dims
.
size
()
==
2
,
"b_dims should be 1 or 2, get %d"
,
b_dims
.
size
());
if
(
b_dims
.
size
()
==
1
)
{
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
D
,
"FCBias shapes must be %d."
,
D
);
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"FCBias shapes must be 1x%d."
,
D
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
D
,
"FCBias shapes must be 1x%d."
,
D
);
}
}
ctx
->
SetOutputDim
(
"Out"
,
{
ins_dims
[
0
][
0
],
D
});
...
...
paddle/fluid/platform/init.cc
...
...
@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
  } catch (const std::exception &exp) {
    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
  }
-#else
-  LOG(WARNING)
-      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
  InitDevices(init_p2p, devices);
}
...
...
@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
  } catch (const std::exception &exp) {
    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
  }
-#else
-  LOG(WARNING)
-      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
  for (size_t i = 0; i < devices.size(); ++i) {
...
...