Commit af15f6f0 in PaddlePaddle/Paddle (unverified)
Author: Yan Chunwei, committed via GitHub on Aug 31, 2018
fea/refine fuse (#13076)

Parent: 819af27d
Showing 20 changed files with 545 additions and 226 deletions (+545 −226)
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc              +1    −1
paddle/fluid/framework/ir/fc_fuse_pass.cc                          +37   −75
paddle/fluid/framework/ir/fc_fuse_pass.h                           +2    −1
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc                   +7    −2
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc                     +4    −4
paddle/fluid/framework/ir/fuse_pass_base.h                         +18   −2
paddle/fluid/framework/ir/graph_pattern_detector.cc                +180  −1
paddle/fluid/framework/ir/graph_pattern_detector.h                 +81   −15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc         +33   −0
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc               +5    −5
paddle/fluid/inference/analysis/analyzer.cc                        +0    −1
paddle/fluid/inference/analysis/analyzer_tester.cc                 +19   −0
paddle/fluid/inference/analysis/argument.h                         +2    −1
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc   +2    −2
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc                +4    −3
paddle/fluid/inference/analysis/fluid_to_ir_pass.h                 +18   −9
paddle/fluid/inference/analysis/ir_pass_manager.cc                 +3    −1
paddle/fluid/inference/api/analysis_predictor.cc                   +77   −102
paddle/fluid/inference/api/analysis_predictor.h                    +51   −0
paddle/fluid/inference/api/api_impl.cc                             +1    −1
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

```diff
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
   auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    auto* while_pat_node = gpd.pattern().RetriveNode("while");
+    auto* while_pat_node = gpd.pattern().RetrieveNode("while");
     auto* while_node = subgraph.at(while_pat_node);
     marked_nodes.insert(while_node);
   };
```
paddle/fluid/framework/ir/fc_fuse_pass.cc

```diff
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
 }
 
 void BuildFCPattern(PDPattern* pattern) {
-  // make sure the selected MUL op has one input argument is a parameter.
-  auto* mul_parameter_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->outputs.size() == 1UL &&
-               node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
-               node->Var()->Persistable();  // check is a parameter
-      },
-      "mul_weight" /*name*/);
-
-  auto* mul_tmp_input_var = pattern->NewNode(
-      [](Node* node) {
-        bool result =
-            node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
-            !node->Var()->Persistable();  // this input is not an parameter.
-        if (!result) return false;
-        // check whether one output is MUL op.
-        for (auto* op : node->outputs) {
-          if (op->IsOp() && op->Op()->Type() == "mul") return true;
-        }
-        return false;
-      },
-      "mul_tmp_var" /*name*/);
-
-  // select a MUL op
-  auto* mul_op = pattern->NewNode(
-      [](Node* node) {
-        return node->IsOp() &&               // start from an Op
-               node->Op()->Type() == "mul";  // type is mul
-        // the output should be consumed only by one element_add, that check
-        // leaves in a Var PDNode.
-      },
-      "mul" /*name*/);
-
-  // make sure the MUL op's output has only one consumer and links to an
-  // ELEMENTWISE_ADD op.
-  auto* mul_out_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() &&                  // starts from a Var
-               node->outputs.size() == 1UL &&    // only has one consumer
-               node->outputs.front()->IsOp() &&  // check basic logic
-               node->Var() &&                    // not a ControlDepVar
-               node->outputs.front()->Op()->Type() ==
-                   "elementwise_add";  // a very strong validation
-      },
-      "mul_out");
-  // this check is not essential, just to make the corresponding variable Node
-  // retrival easier.
-  auto* elementwise_add_tmp_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
-               VarOutLinksToOp(node, "elementwise_add");
-      },
-      "elementwise_add_tmpvar");
-
-  // select an ELEMENTWISE_ADD op
-  auto* elementwise_add_op = pattern->NewNode(
-      [](Node* node) {
-        return node->IsOp() && node->Op()->Type() == "elementwise_add";
-      },
-      "elementwise_add" /*name*/);
-
-  // get the ELEMENTWISE_ADD op's output
-  auto* elementwise_add_out_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
-               node->inputs.front()->Op()->Type() == "elementwise_add";
-      },
-      "elementwise_add_out");
-
-  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
-      .LinksTo({mul_out_var});
+  // Create Operators
+  auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
+  auto* elementwise_add_op =
+      pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
+
+  // Create variables
+  // w
+  auto* mul_weight_var = pattern->NewNode("mul_weight")
+                             ->AsInput()
+                             ->assert_is_op_nth_input("mul", "Y", 0);
+  // x
+  auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
+                          ->AsInput()
+                          ->assert_is_op_nth_input("mul", "X", 0);
+  // intermediate variable, will be removed in the IR after fuse.
+  auto* mul_out_var = pattern->NewNode("mul_out")
+                          ->AsIntermediate()
+                          ->assert_is_only_output_of_op("mul")
+                          ->assert_is_op_input("elementwise_add");
+  // bias
+  auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
+                                      ->assert_is_op_input("elementwise_add")
+                                      ->AsInput();
+  // output
+  auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
+                                      ->AsOutput()
+                                      ->assert_is_op_output("elementwise_add");
+
+  mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
   elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
       .LinksTo({elementwise_add_out_var});
 }
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
 std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   PADDLE_ENFORCE(graph.get());
+  FusePassBase::Init("fc", graph.get());
+
   std::unordered_set<Node*> nodes2delete;
 
   GraphPatternDetector gpd;
   BuildFCPattern(gpd.mutable_pattern());
 
-#define GET_NODE(id)                                              \
-  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)),  \
-                 "pattern has no Node called %s", #id);           \
-  auto* id = subgraph.at(gpd.pattern().RetriveNode(#id));         \
-  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
+#define GET_NODE(id)                                               \
+  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)),  \
+                 "pattern has no Node called %s", #id);            \
+  auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id));         \
+  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
+  int found_fc_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     graph->RemoveNode(mul);
     graph->RemoveNode(elementwise_add);
     graph->RemoveNode(mul_out);  // tmp variable
+
+    found_fc_count++;
   };
 
   gpd(graph.get(), handler);
 
+  AddStatis(found_fc_count);
   return graph;
 }
```
paddle/fluid/framework/ir/fc_fuse_pass.h

```diff
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"
@@ -23,7 +24,7 @@ namespace ir {
 /*
  * Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
  */
-class FCFusePass : public Pass {
+class FCFusePass : public FusePassBase {
  public:
   virtual ~FCFusePass() {}
```
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

```diff
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
            const std::vector<std::string>& outputs) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
-  op->SetInput("Xs", inputs);
-  op->SetOutput("Ys", outputs);
+  if (type == "mul") {
+    op->SetInput("X", {inputs[0]});
+    op->SetInput("Y", {inputs[1]});
+  } else if (type == "elementwise_add") {
+    op->SetInput("X", inputs);
+  }
+  op->SetOutput("Out", outputs);
 }
 
 // a->OP0->b
```
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

```diff
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
+    auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
     marked_nodes.insert(id);
   };
   gpd(graph.get(), handler);
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
 #undef GET_NODE
 #undef SET_IN
 
-  LOG(INFO) << "hidden_n: " << hidden_n->Name();
-  LOG(INFO) << "cell: " << cell_n->Name();
-  LOG(INFO) << "xx: " << xx_n->Name();
+  VLOG(4) << "hidden_n: " << hidden_n->Name();
+  VLOG(4) << "cell: " << cell_n->Name();
+  VLOG(4) << "xx: " << xx_n->Name();
 
   op_desc.SetInput("H0", {});
   op_desc.SetInput("C0", {});
```
paddle/fluid/framework/ir/fuse_pass_base.h

```diff
@@ -22,21 +22,37 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-static const char kParamScopeAttr[] = "param_scope";
+static const char kParamScopeAttr[] = "__param_scope__";
+static const char kFuseStatisAttr[] = "__fuse_statis__";
 
 class FusePassBase : public Pass {
  public:
-  void Init(Graph* graph) const { graph_ = graph; }
+  void Init(const std::string& repr, Graph* graph) const {
+    repr_ = repr;
+    graph_ = graph;
+  }
 
   Scope* param_scope() const {
     PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
     return graph_->Get<framework::Scope*>(kParamScopeAttr);
   }
 
+  void AddStatis(int count_of_fused) const {
+    PADDLE_ENFORCE(graph_);
+    PADDLE_ENFORCE(!repr_.empty());
+    if (!graph_->Has(kFuseStatisAttr)) {
+      graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
+    }
+    auto& info =
+        graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
+    info[repr_] = count_of_fused;
+  }
+
   virtual ~FusePassBase() {}
 
  protected:
   mutable Graph* graph_;
+  mutable std::string repr_;
 };
 
 }  // namespace ir
```
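Taken together, these changes to FusePassBase define a small reporting protocol: a fuse pass names itself via Init(repr, graph), rewrites the graph, and publishes how many subgraphs it fused with AddStatis(), which fills the graph's kFuseStatisAttr map. A minimal sketch of a conforming pass follows; MyFusePass and its body are illustrative placeholders, not code from this commit.

```cpp
// Sketch of a pass built on the extended FusePassBase (names are
// hypothetical; only the Init/AddStatis protocol comes from this commit).
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class MyFusePass : public FusePassBase {
 public:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    // Register a unique repr so AddStatis knows which slot of the
    // kFuseStatisAttr map this pass owns.
    FusePassBase::Init("my_fuse", graph.get());

    int found_count = 0;
    // ... detect subgraphs, rewrite them, increment found_count per fuse ...

    // Record {"my_fuse": found_count} on the graph, where later stages
    // (e.g. the analyzer tests in this commit) can read it back.
    AddStatis(found_count);
    return graph;
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```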
paddle/fluid/framework/ir/graph_pattern_detector.cc

```diff
@@ -27,6 +27,19 @@ namespace ir {
 size_t PDPattern::id_ = 0UL;
 
+PDNode* PDPattern::NewNode(const std::string& name) {
+  if (!name.empty()) {
+    PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
+                      "PDNode's name should be unique, get duplicate [%s]",
+                      name);
+  }
+
+  nodes_.emplace_back(new PDNode(this, name));
+  auto* cur = nodes_.back().get();
+  node_map_[name] = cur;
+  return cur;
+}
+
 PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
   return cur;
 }
 
-PDNode* PDPattern::RetriveNode(const std::string& id) const {
+PDNode* PDPattern::RetrieveNode(const std::string& id) const {
   auto it = node_map_.find(id);
   if (it == node_map_.end()) {
     return nullptr;
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
   auto subgraphs = DetectPatterns();
   UniquePatterns(&subgraphs);
   RemoveOverlappedMatch(&subgraphs);
+  ValidateByNodeRole(&subgraphs);
 
+  if (subgraphs.empty()) return;
   LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
   int id = 0;
   for (auto& g : subgraphs) {
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
       }
     }
   }
+  // Check to early stop if some PDNode can't find matched Node.
+  for (auto& pdnode : pattern_.nodes()) {
+    if (!pdnodes2nodes_.count(pdnode.get())) {
+      VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
+      return false;
+    }
+  }
   VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
 
   return !pdnodes2nodes_.empty();
 }
 
+// The intermediate Nodes can only link to the nodes inside the pattern, or this
+// subgraph will be droped.
+void GraphPatternDetector::ValidateByNodeRole(
+    std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
+  std::vector<GraphPatternDetector::subgraph_t> result;
+
+  subgraphs->erase(
+      std::remove_if(
+          subgraphs->begin(), subgraphs->end(),
+          [](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
+            // Collect the inputs and outputs.
+            std::unordered_set<Node*> ios;
+            for (auto& item : subgraph) {
+              if (!item.first->IsIntermediate()) {
+                ios.insert(item.second);
+              }
+            }
+            for (auto& item : subgraph) {
+              if (item.first->IsIntermediate()) {
+                for (auto* x : item.second->inputs) {
+                  if (!ios.count(x)) {
+                    return true;
+                  }
+                }
+                for (auto* x : item.second->outputs) {
+                  if (!ios.count(x)) {
+                    return true;
+                  }
+                }
+              }
+            }
+            return false;
+          }),
      subgraphs->end());
+}
+
 struct HitGroup {
   std::unordered_map<PDNode*, Node*> roles;
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
   // in edges of PDNodes.
   for (const auto& edge : pattern_.edges()) {
     VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
+    // TODO(Superjomn) Fix bug here, the groups might be duplicate here.
     // Each role has two PDNodes, which indicates two roles.
     // Detect two Nodes that can match these two roles and they are connected.
     auto& pre_groups = bi_records[step % 2];
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
     // source -> target
     for (Node* source : pdnodes2nodes_[edge.first]) {
       for (Node* target : pdnodes2nodes_[edge.second]) {
+        VLOG(8) << "check " << source->id() << " -- " << target->id();
         // TODO(Superjomn) add some prune strategies.
         for (const auto& group : pre_groups) {
           HitGroup new_group = group;
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
       }
     }
     VLOG(3) << "step " << step << " get records: " << cur_groups.size();
+    for (auto& group : cur_groups) {
+      for (auto& item : group.roles) {
+        VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
+      }
+      VLOG(4) << "=========================================================";
+    }
   }
 
   for (auto& group : bi_records[step % 2]) {
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
   return *this;
 }
 
+PDNode* PDNode::assert_is_op() {
+  asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
+  return this;
+}
+PDNode* PDNode::assert_is_op(const std::string& op_type) {
+  asserts_.emplace_back([this, op_type](Node* x) {
+    return x && x->IsOp() && x->Op()->Type() == op_type;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_var() {
+  asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
+  return this;
+}
+PDNode* PDNode::assert_var_not_persistable() {
+  assert_is_var();
+  asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
+  return this;
+}
+PDNode* PDNode::assert_is_persistable_var() {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
+  return this;
+}
+PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
+                                       const std::string& argument, int nth) {
+  assert_is_var();
+  assert_is_op_input(op_type);
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (IsNthInput(x, op, argument, nth)) return true;
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
+                                        const std::string& argument, int nth) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (IsNthOutput(x, op, argument, nth)) return true;
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
+          op->inputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
+          op->outputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
+  assert_is_op(op_type);
+  asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
+  return this;
+}
+PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
+  assert_is_op(op_type);
+  asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
+  return this;
+}
+PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
+  asserts_.emplace_back(std::move(teller));
+  return this;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
```
paddle/fluid/framework/ir/graph_pattern_detector.h

```diff
@@ -39,14 +39,24 @@ struct PDNode {
   // tell whether an ir::Node* is a candidation for a PDNode.
   using teller_t = std::function<bool(Node*)>;
   enum class Type { kOp, kVar };
+  enum class Role {
+    kUnknown,      // No role,
+    kInput,        // an input and will be retained,
+    kOutput,       // an output and will be retained,
+    kIntermediate  // will be removed after handler.
+  };
 
   // this link to others
   PDNode& LinksTo(const std::vector<PDNode*>& others);
   PDNode& LinksFrom(const std::vector<PDNode*>& others);
 
   bool Tell(Node* node) const {
-    PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
-    return teller_(node);
+    if (teller_) return teller_(node);
+
+    for (auto& asrt : asserts_) {
+      if (!asrt(node)) return false;
+    }
+    return true;
   }
 
   bool IsOp() const { return type_ == Type::kOp; }
@@ -54,10 +64,52 @@ struct PDNode {
   const std::string& name() const { return name_; }
 
-  PDNode(const PDNode&) = delete;
   PDNode& operator=(const PDNode&) = delete;
+  PDNode(const PDNode&) = delete;
+
+  // Mark this node is an Input of a subgraph and will be retained.
+  PDNode* AsInput() {
+    role_ = Role::kInput;
+    return this;
+  }
+  // Mark this node is an Output of a subgraph and will be retained.
+  PDNode* AsOutput() {
+    role_ = Role::kOutput;
+    return this;
+  }
+  // Mark this node will be removed, so all the links should be inside a matched
+  // sub-graph.
+  PDNode* AsIntermediate() {
+    role_ = Role::kIntermediate;
+    return this;
+  }
+
+  bool IsIntermediate() const { return role_ == Role::kIntermediate; }
+  bool IsInput() const { return role_ == Role::kInput; }
+  bool IsOutput() const { return role_ == Role::kOutput; }
+
+  // Assertions, helper functions to simplify the pattern definition.
+  PDNode* assert_is_op();
+  PDNode* assert_is_op(const std::string& op_type);
+  PDNode* assert_is_var();
+  PDNode* assert_var_not_persistable();
+  PDNode* assert_is_persistable_var();
+  PDNode* assert_is_op_output(const std::string& op_type);
+  PDNode* assert_is_op_input(const std::string& op_type);
+  PDNode* assert_is_op_nth_input(const std::string& op_type,
+                                 const std::string& argument, int nth);
+  PDNode* assert_is_op_nth_output(const std::string& op_type,
+                                  const std::string& argument, int nth);
+  PDNode* assert_is_only_input_of_op(const std::string& op_type);
+  PDNode* assert_is_only_output_of_op(const std::string& op_type);
+  PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
+  PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
+  PDNode* assert_more(teller_t&& teller);
 
  private:
+  PDNode(PDPattern* pattern, const std::string& name = "",
+         Type type = Type::kVar)
+      : pattern_(pattern), name_(name), type_(type) {}
   PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
          Type type = Type::kVar)
       : teller_(std::move(teller)),
@@ -71,10 +123,13 @@ struct PDNode {
   friend class PDPattern;
 
+  // Will removed latter.
   teller_t teller_;
+  std::vector<teller_t> asserts_;
   PDPattern* pattern_;
   std::string name_;
   Type type_;
+  Role role_{Role::kUnknown};
 };
 
 /*
@@ -87,19 +142,18 @@ struct PDNode {
  * This pattern can be defined as with the following pseudo codes
  *
  *     // Create two operator PDNodes.
- *     MUL = PDPattern.NewNode()
- *     ELE = PDPattern.NewNode()
+ *     MUL = PDPattern.NewNode().assert_is_op("mul");
+ *     ELE = PDPattern.NewNode().assert_is_op("elementwise_add");
  *     // Create the variable PDNodes.
- *     MUL_out = PDPattern.NewNode()
- *     // Add teller to define some rules that help to filter the target Nodes.
- *     MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
- *     ELE.teller = lambda(node): \
- *                        node->IsOp() && node->Op()->Type == "elementwise_add";
- *     MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
- *                                                  && (ELE in node->outputs)
+ *     MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
+ *                                  .assert_is_op_input("elementwise_add") \
+ *                                  .AsIntermediate();
+ *     // Add relations.
+ *     MUL->LinksTo({MUL_out});
+ *     MUL_out->LinksTo({ELE});
  *
- * One can add more specific tellers for PDNodes or edges, both the Operator
- * and Variable Nodes can be ruled in PDNode.teller.
+ * One can add more specific asserts for PDNodes or edges, both the Operator
+ * and Variable Nodes can be ruled in PDNode.assert_more(...).
  *
  * PDPattern can record the general patterns, such as the pattern represents
  * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
@@ -112,7 +166,8 @@ class PDPattern {
   void AddEdge(PDNode* a, PDNode* b);
 
   PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
-  PDNode* RetriveNode(const std::string& id) const;
+  PDNode* NewNode(const std::string& name = NewID());
+  PDNode* RetrieveNode(const std::string& id) const;
 
   const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
   const std::vector<edge_t>& edges() const { return edges_; }
@@ -185,6 +240,9 @@ class GraphPatternDetector {
   // Remove overlapped match subgraphs, when overlapped, keep the previous one.
   void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
 
+  // Validate whether the intermediate nodes are linked by external nodes.
+  void ValidateByNodeRole(std::vector<subgraph_t>* subgraphs);
+
 #ifdef PADDLE_WITH_TESTING
   FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
   FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
   return var->Name() == op->Op()->Input(argument)[nth];
 }
 
+static bool IsNthOutput(Node* var, Node* op, const std::string& argument,
+                        size_t nth) {
+  PADDLE_ENFORCE(var->IsVar());
+  PADDLE_ENFORCE(op->IsOp());
+  if (op->inputs.size() <= nth) return false;
+  return var->Name() == op->Op()->Output(argument)[nth];
+}
+
 static void GraphSafeRemoveNodes(Graph* graph,
                                  const std::unordered_set<const Node*>& nodes) {
   for (auto* node : nodes) {
```
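As a usage illustration of the reworked PDNode API above, the sketch below declares a tiny pattern with chained asserts and roles, then consumes matches through GraphPatternDetector. It mirrors the style of the rewritten BuildFCPattern, but the function and node names here are illustrative, not part of the commit.

```cpp
// Sketch: a two-node pattern using the new chained assert/role API.
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

void RunDetectorSketch(Graph* graph) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // mul -> mul_out, where mul_out must only be consumed inside the match,
  // so ValidateByNodeRole can safely drop subgraphs that leak it outside.
  auto* mul = pattern->NewNode("mul")->assert_is_op("mul");
  auto* mul_out = pattern->NewNode("mul_out")
                      ->AsIntermediate()  // removed after a successful fuse
                      ->assert_is_only_output_of_op("mul");
  mul->LinksTo({mul_out});

  gpd(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
    // RetrieveNode (the corrected spelling of the old RetriveNode) maps a
    // pattern-side name back to the matched ir::Node*.
    Node* mul_node = subgraph.at(gpd.pattern().RetrieveNode("mul"));
    (void)mul_node;  // ... rewrite the matched subgraph here ...
  });
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```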
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc

```diff
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   ASSERT_LE(count, 2);
 }
 
+TEST(GraphPatternDetector, IntermediateCheck) {
+  ProgramDesc program;
+  Graph graph(program);
+  BuildGraph(&graph);
+
+  // o2->v2->o3
+  // o2->v2->o4
+  // check o2+o3 fuse, should fail because v2 also link to o4.
+  GraphPatternDetector detector;
+  auto* op2 = detector.mutable_pattern()->NewNode(
+      [](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2");
+  auto* op3 = detector.mutable_pattern()->NewNode(
+      [](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3");
+  auto* v2 = detector.mutable_pattern()
+                 ->NewNode(
+                     [](Node* x) {
+                       return x && x->IsVar() && x->Name() == "var2";
+                     },
+                     "var2")
+                 ->AsIntermediate();
+  v2->LinksFrom({op2}).LinksTo({op3});
+
+  int count = 0;
+  detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
+                       Graph* graph) { ++count; });
+  EXPECT_EQ(count, 0);
+
+  count = 0;
+  v2->AsInput();
+  detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
+                       Graph* graph) { ++count; });
+  ASSERT_EQ(count, 1);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
```
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc

```diff
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
 std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(graph.get());
+  FusePassBase::Init("seq_concat_fc_fuse", graph.get());
   GraphPatternDetector detector;
   auto* pattern = detector.mutable_pattern();
   auto* concat_out = BuildSeqExpandConcatPattern(pattern);
   BuildFCPattern(pattern, concat_out);
 
-#define GET_NODE(id, pattern)                              \
-  PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
-                 "pattern has no Node called %s", #id);    \
-  auto* id = subgraph.at(pattern.RetriveNode(#id));        \
-  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
+#define GET_NODE(id, pattern)                               \
+  PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
+                 "pattern has no Node called %s", #id);     \
+  auto* id = subgraph.at(pattern.RetrieveNode(#id));        \
+  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
   detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
```
paddle/fluid/inference/analysis/analyzer.cc

```diff
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
   void AddGraphvizDebugerPass(Pass *pass) {
     auto *debuger_pass = pass->CreateGraphvizDebugerPass();
     if (debuger_pass) {
-      LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
       Register(debuger_pass->repr(), debuger_pass);
     }
   }
```
paddle/fluid/inference/analysis/analyzer_tester.cc

```diff
@@ -16,10 +16,13 @@
 #include <google/protobuf/text_format.h>
 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/analysis_predictor.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/utils/singleton.h"
 #include "paddle/fluid/platform/profiler.h"
 
 DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
@@ -31,6 +34,8 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+using namespace framework;
+
 TEST(Analyzer, analysis_without_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = false;
   Argument argument;
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
       EXPECT_NEAR(data[i], base_data[i], 1e-3);
     }
   }
+
+  if (use_analysis && activate_ir) {
+    AnalysisPredictor *analysis_predictor =
+        dynamic_cast<AnalysisPredictor *>(predictor.get());
+    auto &fuse_statis = analysis_predictor->analysis_argument()
+                            .Get<std::unordered_map<std::string, int>>(
+                                framework::ir::kFuseStatisAttr);
+    for (auto &item : fuse_statis) {
+      LOG(INFO) << "fused " << item.first << " " << item.second;
+    }
+    ASSERT_TRUE(fuse_statis.count("fc"));
+    EXPECT_EQ(fuse_statis.at("fc"), 1);
+  }
 }
 
 // Directly infer with the original model.
```
paddle/fluid/inference/analysis/argument.h

```diff
@@ -64,7 +64,8 @@ struct Argument {
   template <typename T>
   void Set(const std::string& key, T* data) {
     PADDLE_ENFORCE_NOT_NULL(data);
-    PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
+    PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]",
+                   key);
     attrs_[key] = data;
     attr_deleters_[key] = [data, key, this]() {
       VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
```
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

```diff
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
 #include <vector>
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/proto_desc.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
 bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
   ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
   ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
-  PADDLE_ENFORCE(!argument->transformed_program_desc);
   // The transformed_program_desc should inherit all the VarDesc and BlockDesc
   // from the original program desc. The operators of the main block(the first
   // block) should rewritten by data flow graph.
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
     }
   }
-  if (argument_->Has("param_scope")) {
+  if (argument_->Has(framework::ir::kParamScopeAttr)) {
     LOG(WARNING) << "parameter changes in the scope takes effect";
   }
```
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc

```diff
@@ -14,6 +14,7 @@
 #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
                                       const std::string &prog_file,
                                       const std::string &param_file) {
   PADDLE_ENFORCE(argument_);
-  argument_->Set("param_scope", new framework::Scope);
+  argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
   // Load parameters.
   VLOG(3) << "Loading parameters from " << model_dir;
-  LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
-             prog_file, param_file);
+  LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
+             model_dir, prog_file, param_file);
 }
 
 bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
```
paddle/fluid/inference/analysis/fluid_to_ir_pass.h

```diff
@@ -14,12 +14,14 @@
 #pragma once
 
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include "paddle/fluid/inference/analysis/pass.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
 
+using namespace framework;
+
 static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
     ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
     // Load program.
     auto program = LoadProgramDesc(*argument->fluid_model_program_path);
-    argument->origin_program_desc.reset(
-        new framework::proto::ProgramDesc(program));
+    argument->origin_program_desc.reset(new proto::ProgramDesc(program));
     // Create main data flow graph.
     if (!argument->main_dfg) {
       argument->main_dfg.reset(new DataFlowGraph);
     }
-    argument->Set("ir_program_desc", new framework::ProgramDesc(program));
+    argument->Set("ir_program_desc", new ProgramDesc(program));
 
     LOG(INFO) << "Loading parameters";
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
   void Run(DataFlowGraph *graph) override {
     // Call all the IR Passes
-    IRPassManager ir_passes(
-        argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
+    IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
+                            nullptr);
     // Pass the scope from analysis to IR if needed.
-    if (argument_->Has("param_scope")) {
+    if (argument_->Has(ir::kParamScopeAttr)) {
       // Here the address is passed, attention that IR doesn't own the scope, so
       // the real scope in analysis should live during the IR phase.
       ir_passes.graph().Set(
-          "param_scope", new framework::Scope*(
-                             &argument_->Get<framework::Scope>("param_scope")));
+          ir::kParamScopeAttr,
+          new Scope*(&argument_->Get<Scope>(ir::kParamScopeAttr)));
     }
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
     PADDLE_ENFORCE(argument_->main_dfg.get());
     argument_->main_dfg->Build(ir_passes.graph());
+
+    // inherit the arguments from ir.
+    if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
+      argument_->Set(
+          ir::kFuseStatisAttr,
+          new std::unordered_map<std::string, int>(
+              ir_passes.graph().Get<std::unordered_map<std::string, int>>(
+                  ir::kFuseStatisAttr)));
+    }
   }
 
   void EnableParamModify(const std::string &model_dir,
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
  private:
   // Load parameters from a single file or from a directory.
-  bool LoadParams(framework::Scope *scope, const std::string &dir,
+  bool LoadParams(Scope *scope, const std::string &dir,
                   const std::string &prog_file, const std::string &param_file);
 
  private:
```
paddle/fluid/inference/analysis/ir_pass_manager.cc

```diff
@@ -14,6 +14,7 @@
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/scope.h"
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
                              framework::Scope *scope)
     : program_(program) {
   graph_.reset(new framework::ir::Graph(program));
-  if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
+  if (scope)
+    graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope));
 }
 
 void IRPassManager::Apply(const std::vector<std::string> &passes) {
```
paddle/fluid/inference/api/analysis_predictor.cc

```diff
@@ -12,121 +12,96 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/inference/api/analysis_predictor.h"
 #include <memory>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/inference/analysis/analyzer.h"
-#include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
 namespace paddle {
 
-using inference::analysis::Argument;
-using inference::Singleton;
-using inference::analysis::Analyzer;
-using framework::proto::ProgramDesc;
-
-/* This predictor is based on the original native predictor with IR and Analysis
- * support. It will optimize IR and Parameters in the runtime.
- * TODO(Superjomn) Replace the Navive predictor?
- */
-class AnalysisPredictor : public NativePaddlePredictor {
- public:
-  explicit AnalysisPredictor(const NativeConfig& config)
-      : NativePaddlePredictor(config), config_(config) {}
-
-  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
-    VLOG(3) << "Predictor::init()";
-    if (config_.use_gpu) {
-      place_ = paddle::platform::CUDAPlace(config_.device);
-    } else {
-      place_ = paddle::platform::CPUPlace();
-    }
-    PADDLE_ENFORCE(!parent_scope);
-    if (parent_scope) {
-      scope_ = parent_scope;
-      sub_scope_ = &(parent_scope->NewScope());
-    } else {
-      paddle::framework::InitDevices(false);
-      scope_.reset(new paddle::framework::Scope());
-    }
-
-    executor_.reset(new paddle::framework::Executor(place_));
-
-    // Initialize the inference program
-    if (!config_.model_dir.empty()) {
-      // Parameters are saved in separate files sited in
-      // the specified `dirname`.
-      inference_program_ = paddle::inference::Load(
-          executor_.get(), scope_.get(), config_.model_dir);
-    } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
-      // All parameters are saved in a single file.
-      // The file names should be consistent with that used
-      // in Python API `fluid.io.save_inference_model`.
-      inference_program_ = paddle::inference::Load(
-          executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
-    } else {
-      LOG(ERROR) << "fail to load inference model.";
-      return false;
-    }
-    OptimizeInferenceProgram();
-    ctx_ = executor_->Prepare(*inference_program_, 0);
-
-    VLOG(5) << "to create variables";
-    PADDLE_ENFORCE(scope_.get());
-    executor_->CreateVariables(*inference_program_,
-                               sub_scope_ ? sub_scope_ : scope_.get(), 0);
-
-    // Get the feed_target_names and fetch_target_names
-    PrepareFeedFetch();
-    return true;
-  }
-
-  bool Run(const std::vector<PaddleTensor>& inputs,
-           std::vector<PaddleTensor>* output_data,
-           int batch_size = -1) override {
-    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
-  }
-
-  void OptimizeInferenceProgram() {
-    LOG(INFO) << "optimize begin";
-    FLAGS_IA_enable_ir = true;
-    FLAGS_IA_enable_tensorrt_subgraph_engine = false;
-    FLAGS_IA_output_storage_path = "";  // Don't output the model.
-    // Analyze inference_program
-    Argument argument;
-    if (!config_.model_dir.empty()) {
-      argument.fluid_model_dir.reset(new std::string(config_.model_dir));
-    } else {
-      PADDLE_ENFORCE(
-          !config_.param_file.empty(),
-          "Either model_dir or (param_file, prog_file) should be set.");
-      PADDLE_ENFORCE(!config_.prog_file.empty());
-      argument.fluid_model_program_path.reset(
-          new std::string(config_.prog_file));
-      argument.fluid_model_param_path.reset(
-          new std::string(config_.param_file));
-    }
-    argument.origin_program_desc.reset(
-        new ProgramDesc(*inference_program_->Proto()));
-    Singleton<Analyzer>::Global().Run(&argument);
-    CHECK(argument.transformed_program_desc);
-    VLOG(5) << "to prepare executor";
-    // LOG(INFO) << "transformed_parogram_desc " <<
-    // argument.transformed_program_desc->DebugString();
-    inference_program_.reset(
-        new framework::ProgramDesc(*argument.transformed_program_desc));
-    PADDLE_ENFORCE(argument.Has("param_scope"));
-    // Update scope.
-    scope_.reset(argument.Release<framework::Scope>("param_scope"));
-    LOG(INFO) << "optimize end ==";
-  }
-
- private:
-  NativeConfig config_;
-};
+bool AnalysisPredictor::Init(
+    const std::shared_ptr<framework::Scope>& parent_scope) {
+  VLOG(3) << "Predictor::init()";
+  if (config_.use_gpu) {
+    place_ = paddle::platform::CUDAPlace(config_.device);
+  } else {
+    place_ = paddle::platform::CPUPlace();
+  }
+  PADDLE_ENFORCE(!parent_scope);
+  if (parent_scope) {
+    scope_ = parent_scope;
+    sub_scope_ = &(parent_scope->NewScope());
+  } else {
+    paddle::framework::InitDevices(false);
+    scope_.reset(new paddle::framework::Scope());
+  }
+
+  executor_.reset(new paddle::framework::Executor(place_));
+
+  // Initialize the inference program
+  if (!config_.model_dir.empty()) {
+    // Parameters are saved in separate files sited in
+    // the specified `dirname`.
+    inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
+                                                 config_.model_dir);
+  } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
+    // All parameters are saved in a single file.
+    // The file names should be consistent with that used
+    // in Python API `fluid.io.save_inference_model`.
+    inference_program_ = paddle::inference::Load(
+        executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
+  } else {
+    LOG(ERROR) << "fail to load inference model.";
+    return false;
+  }
+
+  OptimizeInferenceProgram();
+  ctx_ = executor_->Prepare(*inference_program_, 0);
+
+  VLOG(5) << "to create variables";
+  PADDLE_ENFORCE(scope_.get());
+  executor_->CreateVariables(*inference_program_,
+                             sub_scope_ ? sub_scope_ : scope_.get(), 0);
+
+  // Get the feed_target_names and fetch_target_names
+  PrepareFeedFetch();
+  return true;
+}
+
+void AnalysisPredictor::OptimizeInferenceProgram() {
+  LOG(INFO) << "optimize begin";
+  FLAGS_IA_enable_ir = true;
+  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
+  FLAGS_IA_output_storage_path = "";  // Don't output the model.
+  // Analyze inference_program
+  if (!config_.model_dir.empty()) {
+    argument_.fluid_model_dir.reset(new std::string(config_.model_dir));
+  } else {
+    PADDLE_ENFORCE(
+        !config_.param_file.empty(),
+        "Either model_dir or (param_file, prog_file) should be set.");
+    PADDLE_ENFORCE(!config_.prog_file.empty());
+    argument_.fluid_model_program_path.reset(
+        new std::string(config_.prog_file));
+    argument_.fluid_model_param_path.reset(
+        new std::string(config_.param_file));
+  }
+  argument_.origin_program_desc.reset(
+      new ProgramDesc(*inference_program_->Proto()));
+  Analyzer().Run(&argument_);
+  CHECK(argument_.transformed_program_desc);
+  VLOG(5) << "to prepare executor";
+  // LOG(INFO) << "transformed_parogram_desc " <<
+  // argument.transformed_program_desc->DebugString();
+  inference_program_.reset(
+      new framework::ProgramDesc(*argument_.transformed_program_desc));
+  PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
+  // Update scope.
+  scope_.reset(
+      argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
+  LOG(INFO) << "optimize end ==";
+}
 
 template <>
 std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
```
paddle/fluid/inference/api/analysis_predictor.h (new file, 0 → 100644)

```cpp
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"

namespace paddle {

using inference::analysis::Argument;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;

/* This predictor is based on the original native predictor with IR and Analysis
 * support. It will optimize IR and Parameters in the runtime.
 * TODO(Superjomn) Replace the Navive predictor?
 */
class AnalysisPredictor : public NativePaddlePredictor {
 public:
  explicit AnalysisPredictor(const NativeConfig& config)
      : NativePaddlePredictor(config), config_(config) {}

  bool Init(const std::shared_ptr<framework::Scope>& parent_scope);

  bool Run(const std::vector<PaddleTensor>& inputs,
           std::vector<PaddleTensor>* output_data,
           int batch_size = -1) override {
    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
  }

  void OptimizeInferenceProgram();

  Argument& analysis_argument() { return argument_; }

 private:
  NativeConfig config_;
  Argument argument_;
};

}  // namespace paddle
```
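For context, here is a hedged sketch of how client code reaches this predictor. The diff view above truncates the template arguments of the CreatePaddlePredictor specialization that analysis_predictor.cc defines, so the PaddleEngineKind::kAnalysis tag and the model path below are assumptions, not facts from this commit.

```cpp
// Sketch: driving the new AnalysisPredictor from client code.
// The kAnalysis engine-kind tag is an assumption (truncated in the diff).
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::NativeConfig config;
  config.model_dir = "./my_model";  // hypothetical model directory
  config.use_gpu = false;

  // Init() calls OptimizeInferenceProgram(), i.e. runs the Analyzer and the
  // IR fuse passes shown in this commit, before the first Run().
  auto predictor = paddle::CreatePaddlePredictor<
      paddle::NativeConfig, paddle::PaddleEngineKind::kAnalysis>(config);

  std::vector<paddle::PaddleTensor> inputs, outputs;
  // ... fill `inputs` with feed tensors ...
  predictor->Run(inputs, &outputs);
  return 0;
}
```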
paddle/fluid/inference/api/api_impl.cc

```diff
@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
   for (auto *op : inference_program_->Block(0).AllOps()) {
     if (op->Type() == "feed") {
       int idx = boost::get<int>(op->GetAttr("col"));
-      if (feeds_.size() <= idx) {
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
         feeds_.resize(idx + 1);
       }
       feeds_[idx] = op;
```