Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
af15f6f0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
af15f6f0
编写于
8月 31, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
8月 31, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fea/refine fuse (#13076)
上级
819af27d
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
545 addition
and
226 deletion
+545
-226
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+1
-1
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+37
-75
paddle/fluid/framework/ir/fc_fuse_pass.h
paddle/fluid/framework/ir/fc_fuse_pass.h
+2
-1
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+7
-2
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+4
-4
paddle/fluid/framework/ir/fuse_pass_base.h
paddle/fluid/framework/ir/fuse_pass_base.h
+18
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+180
-1
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+81
-15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
+33
-0
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+5
-5
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+0
-1
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+19
-0
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+2
-1
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+2
-2
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
+4
-3
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+18
-9
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+3
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+77
-102
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+51
-0
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+1
-1
未找到文件。
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
af15f6f0
...
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
...
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
auto
*
while_pat_node
=
gpd
.
pattern
().
RetriveNode
(
"while"
);
auto
*
while_pat_node
=
gpd
.
pattern
().
Retri
e
veNode
(
"while"
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
marked_nodes
.
insert
(
while_node
);
marked_nodes
.
insert
(
while_node
);
};
};
...
...
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
af15f6f0
...
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
...
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
}
}
void
BuildFCPattern
(
PDPattern
*
pattern
)
{
void
BuildFCPattern
(
PDPattern
*
pattern
)
{
// make sure the selected MUL op has one input argument is a parameter.
// Create Operators
auto
*
mul_parameter_var
=
pattern
->
NewNode
(
auto
*
mul_op
=
pattern
->
NewNode
(
"mul"
)
->
assert_is_op
(
"mul"
);
[](
Node
*
node
)
{
auto
*
elementwise_add_op
=
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
==
1UL
&&
pattern
->
NewNode
(
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
);
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"mul"
&&
node
->
Var
()
&&
// Create variables
node
->
Var
()
->
Persistable
();
// check is a parameter
// w
},
auto
*
mul_weight_var
=
pattern
->
NewNode
(
"mul_weight"
)
"mul_weight"
/*name*/
);
->
AsInput
()
->
assert_is_op_nth_input
(
"mul"
,
"Y"
,
0
);
auto
*
mul_tmp_input_var
=
pattern
->
NewNode
(
// x
[](
Node
*
node
)
{
auto
*
mul_tmp_var
=
pattern
->
NewNode
(
"mul_tmp_var"
)
bool
result
=
->
AsInput
()
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
->
assert_is_op_nth_input
(
"mul"
,
"X"
,
0
);
!
node
->
Var
()
->
Persistable
();
// this input is not an parameter.
// intermediate variable, will be removed in the IR after fuse.
if
(
!
result
)
return
false
;
auto
*
mul_out_var
=
pattern
->
NewNode
(
"mul_out"
)
// check whether one output is MUL op.
->
AsIntermediate
()
for
(
auto
*
op
:
node
->
outputs
)
{
->
assert_is_only_output_of_op
(
"mul"
)
if
(
op
->
IsOp
()
&&
op
->
Op
()
->
Type
()
==
"mul"
)
return
true
;
->
assert_is_op_input
(
"elementwise_add"
);
}
// bias
return
false
;
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
"elementwise_add_tmpvar"
)
},
->
assert_is_op_input
(
"elementwise_add"
)
"mul_tmp_var"
/*name*/
);
->
AsInput
();
// output
// select a MUL op
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
"elementwise_add_out"
)
auto
*
mul_op
=
pattern
->
NewNode
(
->
AsOutput
()
[](
Node
*
node
)
{
->
assert_is_op_output
(
"elementwise_add"
);
return
node
->
IsOp
()
&&
// start from an Op
node
->
Op
()
->
Type
()
==
"mul"
;
// type is mul
mul_op
->
LinksFrom
({
mul_weight_var
,
mul_tmp_var
}).
LinksTo
({
mul_out_var
});
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul"
/*name*/
);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto
*
mul_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
// starts from a Var
node
->
outputs
.
size
()
==
1UL
&&
// only has one consumer
node
->
outputs
.
front
()
->
IsOp
()
&&
// check basic logic
node
->
Var
()
&&
// not a ControlDepVar
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
// a very strong validation
},
"mul_out"
);
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
VarOutLinksToOp
(
node
,
"elementwise_add"
);
},
"elementwise_add_tmpvar"
);
// select an ELEMENTWISE_ADD op
auto
*
elementwise_add_op
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add"
/*name*/
);
// get the ELEMENTWISE_ADD op's output
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
inputs
.
size
()
==
1UL
&&
node
->
Var
()
&&
node
->
inputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add_out"
);
mul_op
->
LinksFrom
({
mul_parameter_var
,
mul_tmp_input_var
})
.
LinksTo
({
mul_out_var
});
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
.
LinksTo
({
elementwise_add_out_var
});
.
LinksTo
({
elementwise_add_out_var
});
}
}
...
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
...
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std
::
unique_ptr
<
ir
::
Graph
>
FCFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
FCFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"fc"
,
graph
.
get
());
std
::
unordered_set
<
Node
*>
nodes2delete
;
std
::
unordered_set
<
Node
*>
nodes2delete
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
BuildFCPattern
(
gpd
.
mutable_pattern
());
BuildFCPattern
(
gpd
.
mutable_pattern
());
#define GET_NODE(id) \
#define GET_NODE(id)
\
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
"pattern has no Node called %s", #id);
\
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
auto* id = subgraph.at(gpd.pattern().Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int
found_fc_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle FC fuse"
;
VLOG
(
4
)
<<
"handle FC fuse"
;
...
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
graph
->
RemoveNode
(
mul
);
graph
->
RemoveNode
(
mul
);
graph
->
RemoveNode
(
elementwise_add
);
graph
->
RemoveNode
(
elementwise_add
);
graph
->
RemoveNode
(
mul_out
);
// tmp variable
graph
->
RemoveNode
(
mul_out
);
// tmp variable
found_fc_count
++
;
};
};
gpd
(
graph
.
get
(),
handler
);
gpd
(
graph
.
get
(),
handler
);
AddStatis
(
found_fc_count
);
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/ir/fc_fuse_pass.h
浏览文件 @
af15f6f0
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -23,7 +24,7 @@ namespace ir {
...
@@ -23,7 +24,7 @@ namespace ir {
/*
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
*/
class
FCFusePass
:
public
Pass
{
class
FCFusePass
:
public
FusePassBase
{
public:
public:
virtual
~
FCFusePass
()
{}
virtual
~
FCFusePass
()
{}
...
...
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
浏览文件 @
af15f6f0
...
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
...
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
const
std
::
vector
<
std
::
string
>&
outputs
)
{
const
std
::
vector
<
std
::
string
>&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetType
(
type
);
op
->
SetInput
(
"Xs"
,
inputs
);
if
(
type
==
"mul"
)
{
op
->
SetOutput
(
"Ys"
,
outputs
);
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetInput
(
"X"
,
inputs
);
}
op
->
SetOutput
(
"Out"
,
outputs
);
}
}
// a->OP0->b
// a->OP0->b
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
af15f6f0
...
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
RetriveNode
(
"any_node"
));
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
Retri
e
veNode
(
"any_node"
));
marked_nodes
.
insert
(
id
);
marked_nodes
.
insert
(
id
);
};
};
gpd
(
graph
.
get
(),
handler
);
gpd
(
graph
.
get
(),
handler
);
...
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
#undef GET_NODE
#undef GET_NODE
#undef SET_IN
#undef SET_IN
LOG
(
INFO
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
VLOG
(
4
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
LOG
(
INFO
)
<<
"cell: "
<<
cell_n
->
Name
();
VLOG
(
4
)
<<
"cell: "
<<
cell_n
->
Name
();
LOG
(
INFO
)
<<
"xx: "
<<
xx_n
->
Name
();
VLOG
(
4
)
<<
"xx: "
<<
xx_n
->
Name
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
...
...
paddle/fluid/framework/ir/fuse_pass_base.h
浏览文件 @
af15f6f0
...
@@ -22,21 +22,37 @@ namespace paddle {
...
@@ -22,21 +22,37 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
const
char
kParamScopeAttr
[]
=
"param_scope"
;
static
const
char
kParamScopeAttr
[]
=
"__param_scope__"
;
static
const
char
kFuseStatisAttr
[]
=
"__fuse_statis__"
;
class
FusePassBase
:
public
Pass
{
class
FusePassBase
:
public
Pass
{
public:
public:
void
Init
(
Graph
*
graph
)
const
{
graph_
=
graph
;
}
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
{
repr_
=
repr
;
graph_
=
graph
;
}
Scope
*
param_scope
()
const
{
Scope
*
param_scope
()
const
{
PADDLE_ENFORCE
(
graph_
->
Has
(
kParamScopeAttr
));
PADDLE_ENFORCE
(
graph_
->
Has
(
kParamScopeAttr
));
return
graph_
->
Get
<
framework
::
Scope
*>
(
kParamScopeAttr
);
return
graph_
->
Get
<
framework
::
Scope
*>
(
kParamScopeAttr
);
}
}
void
AddStatis
(
int
count_of_fused
)
const
{
PADDLE_ENFORCE
(
graph_
);
PADDLE_ENFORCE
(
!
repr_
.
empty
());
if
(
!
graph_
->
Has
(
kFuseStatisAttr
))
{
graph_
->
Set
(
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
);
}
auto
&
info
=
graph_
->
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
kFuseStatisAttr
);
info
[
repr_
]
=
count_of_fused
;
}
virtual
~
FusePassBase
()
{}
virtual
~
FusePassBase
()
{}
protected:
protected:
mutable
Graph
*
graph_
;
mutable
Graph
*
graph_
;
mutable
std
::
string
repr_
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
af15f6f0
...
@@ -27,6 +27,19 @@ namespace ir {
...
@@ -27,6 +27,19 @@ namespace ir {
size_t
PDPattern
::
id_
=
0UL
;
size_t
PDPattern
::
id_
=
0UL
;
PDNode
*
PDPattern
::
NewNode
(
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
"PDNode's name should be unique, get duplicate [%s]"
,
name
);
}
nodes_
.
emplace_back
(
new
PDNode
(
this
,
name
));
auto
*
cur
=
nodes_
.
back
().
get
();
node_map_
[
name
]
=
cur
;
return
cur
;
}
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
...
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
...
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
return
cur
;
return
cur
;
}
}
PDNode
*
PDPattern
::
RetriveNode
(
const
std
::
string
&
id
)
const
{
PDNode
*
PDPattern
::
Retri
e
veNode
(
const
std
::
string
&
id
)
const
{
auto
it
=
node_map_
.
find
(
id
);
auto
it
=
node_map_
.
find
(
id
);
if
(
it
==
node_map_
.
end
())
{
if
(
it
==
node_map_
.
end
())
{
return
nullptr
;
return
nullptr
;
...
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
...
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
auto
subgraphs
=
DetectPatterns
();
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
ValidateByNodeRole
(
&
subgraphs
);
if
(
subgraphs
.
empty
())
return
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
int
id
=
0
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
for
(
auto
&
g
:
subgraphs
)
{
...
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
...
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
}
}
}
}
// Check to early stop if some PDNode can't find matched Node.
for
(
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
!
pdnodes2nodes_
.
count
(
pdnode
.
get
()))
{
VLOG
(
4
)
<<
pdnode
->
name
()
<<
" can't find matched Node, early stop"
;
return
false
;
}
}
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
return
!
pdnodes2nodes_
.
empty
();
return
!
pdnodes2nodes_
.
empty
();
}
}
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void
GraphPatternDetector
::
ValidateByNodeRole
(
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>
result
;
subgraphs
->
erase
(
std
::
remove_if
(
subgraphs
->
begin
(),
subgraphs
->
end
(),
[](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
bool
{
// Collect the inputs and outputs.
std
::
unordered_set
<
Node
*>
ios
;
for
(
auto
&
item
:
subgraph
)
{
if
(
!
item
.
first
->
IsIntermediate
())
{
ios
.
insert
(
item
.
second
);
}
}
for
(
auto
&
item
:
subgraph
)
{
if
(
item
.
first
->
IsIntermediate
())
{
for
(
auto
*
x
:
item
.
second
->
inputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
for
(
auto
*
x
:
item
.
second
->
outputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
}
}
return
false
;
}),
subgraphs
->
end
());
}
struct
HitGroup
{
struct
HitGroup
{
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
...
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
// in edges of PDNodes.
// in edges of PDNodes.
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
VLOG
(
4
)
<<
"check "
<<
edge
.
first
->
name
()
<<
" -> "
<<
edge
.
second
->
name
();
VLOG
(
4
)
<<
"check "
<<
edge
.
first
->
name
()
<<
" -> "
<<
edge
.
second
->
name
();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles.
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
// Detect two Nodes that can match these two roles and they are connected.
auto
&
pre_groups
=
bi_records
[
step
%
2
];
auto
&
pre_groups
=
bi_records
[
step
%
2
];
...
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
// source -> target
// source -> target
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
VLOG
(
8
)
<<
"check "
<<
source
->
id
()
<<
" -- "
<<
target
->
id
();
// TODO(Superjomn) add some prune strategies.
// TODO(Superjomn) add some prune strategies.
for
(
const
auto
&
group
:
pre_groups
)
{
for
(
const
auto
&
group
:
pre_groups
)
{
HitGroup
new_group
=
group
;
HitGroup
new_group
=
group
;
...
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
}
}
}
}
VLOG
(
3
)
<<
"step "
<<
step
<<
" get records: "
<<
cur_groups
.
size
();
VLOG
(
3
)
<<
"step "
<<
step
<<
" get records: "
<<
cur_groups
.
size
();
for
(
auto
&
group
:
cur_groups
)
{
for
(
auto
&
item
:
group
.
roles
)
{
VLOG
(
4
)
<<
"node "
<<
item
.
second
->
id
()
<<
" as "
<<
item
.
first
->
name
();
}
VLOG
(
4
)
<<
"========================================================="
;
}
}
}
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
...
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
...
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
return
*
this
;
return
*
this
;
}
}
PDNode
*
PDNode
::
assert_is_op
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op
(
const
std
::
string
&
op_type
)
{
asserts_
.
emplace_back
([
this
,
op_type
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
op_type
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_var
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_var_not_persistable
()
{
assert_is_var
();
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
!
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_persistable_var
()
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
IsNthInput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
IsNthOutput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
inputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
outputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_output
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
inputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
outputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_more
(
PDNode
::
teller_t
&&
teller
)
{
asserts_
.
emplace_back
(
std
::
move
(
teller
));
return
this
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
af15f6f0
...
@@ -39,14 +39,24 @@ struct PDNode {
...
@@ -39,14 +39,24 @@ struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode.
// tell whether an ir::Node* is a candidation for a PDNode.
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
enum
class
Type
{
kOp
,
kVar
};
enum
class
Type
{
kOp
,
kVar
};
enum
class
Role
{
kUnknown
,
// No role,
kInput
,
// an input and will be retained,
kOutput
,
// an output and will be retained,
kIntermediate
// will be removed after handler.
};
// this link to others
// this link to others
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
bool
Tell
(
Node
*
node
)
const
{
bool
Tell
(
Node
*
node
)
const
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
if
(
teller_
)
return
teller_
(
node
);
return
teller_
(
node
);
for
(
auto
&
asrt
:
asserts_
)
{
if
(
!
asrt
(
node
))
return
false
;
}
return
true
;
}
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
...
@@ -54,10 +64,52 @@ struct PDNode {
...
@@ -54,10 +64,52 @@ struct PDNode {
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
(
const
PDNode
&
)
=
delete
;
// Mark this node is an Input of a subgraph and will be retained.
PDNode
*
AsInput
()
{
role_
=
Role
::
kInput
;
return
this
;
}
// Mark this node is an Output of a subgraph and will be retained.
PDNode
*
AsOutput
()
{
role_
=
Role
::
kOutput
;
return
this
;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PDNode
*
AsIntermediate
()
{
role_
=
Role
::
kIntermediate
;
return
this
;
}
bool
IsIntermediate
()
const
{
return
role_
==
Role
::
kIntermediate
;
}
bool
IsInput
()
const
{
return
role_
==
Role
::
kInput
;
}
bool
IsOutput
()
const
{
return
role_
==
Role
::
kOutput
;
}
// Assertions, helper functions to simplify the pattern definition.
PDNode
*
assert_is_op
();
PDNode
*
assert_is_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_var
();
PDNode
*
assert_var_not_persistable
();
PDNode
*
assert_is_persistable_var
();
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_more
(
teller_t
&&
teller
);
private:
private:
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
pattern_
(
pattern
),
name_
(
name
),
type_
(
type
)
{}
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
Type
type
=
Type
::
kVar
)
:
teller_
(
std
::
move
(
teller
)),
:
teller_
(
std
::
move
(
teller
)),
...
@@ -71,10 +123,13 @@ struct PDNode {
...
@@ -71,10 +123,13 @@ struct PDNode {
friend
class
PDPattern
;
friend
class
PDPattern
;
// Will removed latter.
teller_t
teller_
;
teller_t
teller_
;
std
::
vector
<
teller_t
>
asserts_
;
PDPattern
*
pattern_
;
PDPattern
*
pattern_
;
std
::
string
name_
;
std
::
string
name_
;
Type
type_
;
Type
type_
;
Role
role_
{
Role
::
kUnknown
};
};
};
/*
/*
...
@@ -87,19 +142,18 @@ struct PDNode {
...
@@ -87,19 +142,18 @@ struct PDNode {
* This pattern can be defined as with the following pseudo codes
* This pattern can be defined as with the following pseudo codes
*
*
* // Create two operator PDNodes.
* // Create two operator PDNodes.
* MUL = PDPattern.NewNode()
* MUL = PDPattern.NewNode()
.assert_is_op("mul");
* ELE = PDPattern.NewNode()
* ELE = PDPattern.NewNode()
.assert_is_op("elementwise_add");
* // Create the variable PDNodes.
* // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode()
* MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
* // Add teller to define some rules that help to filter the target Nodes.
* .assert_is_op_input("elementwise_add") \
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
* .AsIntermediate();
* ELE.teller = lambda(node): \
* // Add relations.
* node->IsOp() && node->Op()->Type == "elementwise_add";
* MUL->LinksTo({MUL_out});
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
* MUL_out->LinksTo({ELE});
* && (ELE in node->outputs)
*
*
* One can add more specific
teller
s for PDNodes or edges, both the Operator
* One can add more specific
assert
s for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.
teller
.
* and Variable Nodes can be ruled in PDNode.
assert_more(...)
.
*
*
* PDPattern can record the general patterns, such as the pattern represents
* PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
...
@@ -112,7 +166,8 @@ class PDPattern {
...
@@ -112,7 +166,8 @@ class PDPattern {
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetriveNode
(
const
std
::
string
&
id
)
const
;
PDNode
*
NewNode
(
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetrieveNode
(
const
std
::
string
&
id
)
const
;
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
...
@@ -185,6 +240,9 @@ class GraphPatternDetector {
...
@@ -185,6 +240,9 @@ class GraphPatternDetector {
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
// Validate whether the intermediate nodes are linked by external nodes.
void
ValidateByNodeRole
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
#ifdef PADDLE_WITH_TESTING
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
...
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
...
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
}
}
static
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
inputs
.
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Output
(
argument
)[
nth
];
}
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
浏览文件 @
af15f6f0
...
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
...
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
ASSERT_LE
(
count
,
2
);
ASSERT_LE
(
count
,
2
);
}
}
TEST
(
GraphPatternDetector
,
IntermediateCheck
)
{
ProgramDesc
program
;
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
GraphPatternDetector
detector
;
auto
*
op2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op2"
;
},
"op2"
);
auto
*
op3
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op3"
;
},
"op3"
);
auto
*
v2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
Name
()
==
"var2"
;
},
"var2"
)
->
AsIntermediate
();
v2
->
LinksFrom
({
op2
}).
LinksTo
({
op3
});
int
count
=
0
;
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
EXPECT_EQ
(
count
,
0
);
count
=
0
;
v2
->
AsInput
();
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
ASSERT_EQ
(
count
,
1
);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
浏览文件 @
af15f6f0
...
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
...
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
std
::
unique_ptr
<
ir
::
Graph
>
SeqConcatFcFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
SeqConcatFcFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
graph
.
get
());
FusePassBase
::
Init
(
"seq_concat_fc_fuse"
,
graph
.
get
());
GraphPatternDetector
detector
;
GraphPatternDetector
detector
;
auto
*
pattern
=
detector
.
mutable_pattern
();
auto
*
pattern
=
detector
.
mutable_pattern
();
auto
*
concat_out
=
BuildSeqExpandConcatPattern
(
pattern
);
auto
*
concat_out
=
BuildSeqExpandConcatPattern
(
pattern
);
BuildFCPattern
(
pattern
,
concat_out
);
BuildFCPattern
(
pattern
,
concat_out
);
#define GET_NODE(id, pattern) \
#define GET_NODE(id, pattern)
\
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(pattern.Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
"pattern has no Node called %s", #id);
\
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
auto* id = subgraph.at(pattern.Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
af15f6f0
...
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
void
AddGraphvizDebugerPass
(
Pass
*
pass
)
{
void
AddGraphvizDebugerPass
(
Pass
*
pass
)
{
auto
*
debuger_pass
=
pass
->
CreateGraphvizDebugerPass
();
auto
*
debuger_pass
=
pass
->
CreateGraphvizDebugerPass
();
if
(
debuger_pass
)
{
if
(
debuger_pass
)
{
LOG
(
INFO
)
<<
" - register debug pass ["
<<
debuger_pass
->
repr
()
<<
"]"
;
Register
(
debuger_pass
->
repr
(),
debuger_pass
);
Register
(
debuger_pass
->
repr
(),
debuger_pass
);
}
}
}
}
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
af15f6f0
...
@@ -16,10 +16,13 @@
...
@@ -16,10 +16,13 @@
#include <google/protobuf/text_format.h>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
...
@@ -31,6 +34,8 @@ namespace paddle {
...
@@ -31,6 +34,8 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
using
namespace
framework
;
TEST
(
Analyzer
,
analysis_without_tensorrt
)
{
TEST
(
Analyzer
,
analysis_without_tensorrt
)
{
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
Argument
argument
;
Argument
argument
;
...
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
}
}
}
}
if
(
use_analysis
&&
activate_ir
)
{
AnalysisPredictor
*
analysis_predictor
=
dynamic_cast
<
AnalysisPredictor
*>
(
predictor
.
get
());
auto
&
fuse_statis
=
analysis_predictor
->
analysis_argument
()
.
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
framework
::
ir
::
kFuseStatisAttr
);
for
(
auto
&
item
:
fuse_statis
)
{
LOG
(
INFO
)
<<
"fused "
<<
item
.
first
<<
" "
<<
item
.
second
;
}
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc"
),
1
);
}
}
}
// Directly infer with the original model.
// Directly infer with the original model.
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
af15f6f0
...
@@ -64,7 +64,8 @@ struct Argument {
...
@@ -64,7 +64,8 @@ struct Argument {
template
<
typename
T
>
template
<
typename
T
>
void
Set
(
const
std
::
string
&
key
,
T
*
data
)
{
void
Set
(
const
std
::
string
&
key
,
T
*
data
)
{
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"duplicate attr called %s"
,
key
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"Duplicate set Argument's attr [%s]"
,
key
);
attrs_
[
key
]
=
data
;
attrs_
[
key
]
=
data
;
attr_deleters_
[
key
]
=
[
data
,
key
,
this
]()
{
attr_deleters_
[
key
]
=
[
data
,
key
,
this
]()
{
VLOG
(
3
)
<<
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
;
VLOG
(
3
)
<<
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
;
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
af15f6f0
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <vector>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
...
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
...
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
bool
DataFlowGraphToFluidPass
::
Initialize
(
Argument
*
argument
)
{
bool
DataFlowGraphToFluidPass
::
Initialize
(
Argument
*
argument
)
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
)
PADDLE_ENFORCE
(
!
argument
->
transformed_program_desc
);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first
// from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph.
// block) should rewritten by data flow graph.
...
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
...
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
}
}
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
framework
::
ir
::
kParamScopeAttr
))
{
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
}
}
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
浏览文件 @
af15f6f0
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
...
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const
std
::
string
&
prog_file
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
)
{
const
std
::
string
&
param_file
)
{
PADDLE_ENFORCE
(
argument_
);
PADDLE_ENFORCE
(
argument_
);
argument_
->
Set
(
"param_scope"
,
new
framework
::
Scope
);
argument_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
);
// Load parameters.
// Load parameters.
VLOG
(
3
)
<<
"Loading parameters from "
<<
model_dir
;
VLOG
(
3
)
<<
"Loading parameters from "
<<
model_dir
;
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
),
model_dir
,
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
)
,
prog_file
,
param_file
);
model_dir
,
prog_file
,
param_file
);
}
}
bool
FluidToIrPass
::
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
bool
FluidToIrPass
::
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
浏览文件 @
af15f6f0
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
using
namespace
framework
;
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
...
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
// Load program.
// Load program.
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
argument
->
origin_program_desc
.
reset
(
argument
->
origin_program_desc
.
reset
(
new
proto
::
ProgramDesc
(
program
));
new
framework
::
proto
::
ProgramDesc
(
program
));
// Create main data flow graph.
// Create main data flow graph.
if
(
!
argument
->
main_dfg
)
{
if
(
!
argument
->
main_dfg
)
{
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
}
}
argument
->
Set
(
"ir_program_desc"
,
new
framework
::
ProgramDesc
(
program
));
argument
->
Set
(
"ir_program_desc"
,
new
ProgramDesc
(
program
));
LOG
(
INFO
)
<<
"Loading parameters"
;
LOG
(
INFO
)
<<
"Loading parameters"
;
// Load parameters to argument if needed.
// Load parameters to argument if needed.
...
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
void
Run
(
DataFlowGraph
*
graph
)
override
{
void
Run
(
DataFlowGraph
*
graph
)
override
{
// Call all the IR Passes
// Call all the IR Passes
IRPassManager
ir_passes
(
IRPassManager
ir_passes
(
argument_
->
Get
<
ProgramDesc
>
(
"ir_program_desc"
),
argument_
->
Get
<
framework
::
ProgramDesc
>
(
"ir_program_desc"
),
nullptr
);
nullptr
);
// Pass the scope from analysis to IR if needed.
// Pass the scope from analysis to IR if needed.
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
ir
::
kParamScopeAttr
))
{
// Here the address is passed, attention that IR doesn't own the scope, so
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
// the real scope in analysis should live during the IR phase.
ir_passes
.
graph
().
Set
(
ir_passes
.
graph
().
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
ir
::
kParamScopeAttr
,
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
)));
new
Scope
*
(
&
argument_
->
Get
<
Scope
>
(
ir
::
kParamScopeAttr
)));
}
}
const
auto
&
ir_passes_to_apply
=
const
auto
&
ir_passes_to_apply
=
...
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
// inherit the arguments from ir.
if
(
ir_passes
.
graph
().
Has
(
ir
::
kFuseStatisAttr
))
{
argument_
->
Set
(
ir
::
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
(
ir_passes
.
graph
().
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
ir
::
kFuseStatisAttr
)));
}
}
}
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
...
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private:
private:
// Load parameters from a single file or from a directory.
// Load parameters from a single file or from a directory.
bool
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
bool
LoadParams
(
Scope
*
scope
,
const
std
::
string
&
dir
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
private:
private:
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
af15f6f0
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
...
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
framework
::
Scope
*
scope
)
framework
::
Scope
*
scope
)
:
program_
(
program
)
{
:
program_
(
program
)
{
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
if
(
scope
)
graph_
->
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
scope
));
if
(
scope
)
graph_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
*
(
scope
));
}
}
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
af15f6f0
...
@@ -12,121 +12,96 @@
...
@@ -12,121 +12,96 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
bool
AnalysisPredictor
::
Init
(
using
inference
::
Singleton
;
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
using
inference
::
analysis
::
Analyzer
;
VLOG
(
3
)
<<
"Predictor::init()"
;
using
framework
::
proto
::
ProgramDesc
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
/* This predictor is based on the original native predictor with IR and Analysis
}
else
{
* support. It will optimize IR and Parameters in the runtime.
place_
=
paddle
::
platform
::
CPUPlace
();
* TODO(Superjomn) Replace the Navive predictor?
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
}
PADDLE_ENFORCE
(
!
parent_scope
);
if
(
parent_scope
)
{
scope_
=
parent_scope
;
sub_scope_
=
&
(
parent_scope
->
NewScope
());
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
// Initialize the inference program
if
(
!
config_
.
model_dir
.
empty
())
{
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
model_dir
);
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
prog_file
,
config_
.
param_file
);
}
else
{
LOG
(
ERROR
)
<<
"fail to load inference model."
;
return
false
;
}
OptimizeInferenceProgram
();
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
VLOG
(
5
)
<<
"to create variables"
;
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch
();
return
true
;
}
}
PADDLE_ENFORCE
(
!
parent_scope
);
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
if
(
parent_scope
)
{
std
::
vector
<
PaddleTensor
>*
output_data
,
scope_
=
parent_scope
;
int
batch_size
=
-
1
)
override
{
sub_scope_
=
&
(
parent_scope
->
NewScope
());
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
}
void
OptimizeInferenceProgram
()
{
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
// Initialize the inference program
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
if
(
!
config_
.
model_dir
.
empty
())
{
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Parameters are saved in separate files sited in
// Analyze inference_program
// the specified `dirname`.
Argument
argument
;
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
if
(
!
config_
.
model_dir
.
empty
())
{
config_
.
model_dir
);
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
}
else
{
// All parameters are saved in a single file.
PADDLE_ENFORCE
(
// The file names should be consistent with that used
!
config_
.
param_file
.
empty
(),
// in Python API `fluid.io.save_inference_model`.
"Either model_dir or (param_file, prog_file) should be set."
);
inference_program_
=
paddle
::
inference
::
Load
(
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
executor_
.
get
(),
scope_
.
get
(),
config_
.
prog_file
,
config_
.
param_file
);
argument
.
fluid_model_program_path
.
reset
(
}
else
{
new
std
::
string
(
config_
.
prog_file
));
LOG
(
ERROR
)
<<
"fail to load inference model."
;
argument
.
fluid_model_param_path
.
reset
(
return
false
;
new
std
::
string
(
config_
.
param_file
));
}
argument
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
CHECK
(
argument
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument
.
Has
(
"param_scope"
));
// Update scope.
scope_
.
reset
(
argument
.
Release
<
framework
::
Scope
>
(
"param_scope"
));
LOG
(
INFO
)
<<
"optimize end =="
;
}
}
private:
OptimizeInferenceProgram
();
NativeConfig
config_
;
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
};
VLOG
(
5
)
<<
"to create variables"
;
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch
();
return
true
;
}
void
AnalysisPredictor
::
OptimizeInferenceProgram
()
{
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Analyze inference_program
if
(
!
config_
.
model_dir
.
empty
())
{
argument_
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
{
PADDLE_ENFORCE
(
!
config_
.
param_file
.
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
argument_
.
fluid_model_program_path
.
reset
(
new
std
::
string
(
config_
.
prog_file
));
argument_
.
fluid_model_param_path
.
reset
(
new
std
::
string
(
config_
.
param_file
));
}
argument_
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Analyzer
().
Run
(
&
argument_
);
CHECK
(
argument_
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument_
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument_
.
Has
(
framework
::
ir
::
kParamScopeAttr
));
// Update scope.
scope_
.
reset
(
argument_
.
Release
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
));
LOG
(
INFO
)
<<
"optimize end =="
;
}
template
<
>
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
...
...
paddle/fluid/inference/api/analysis_predictor.h
0 → 100644
浏览文件 @
af15f6f0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
using
inference
::
analysis
::
Analyzer
;
using
framework
::
proto
::
ProgramDesc
;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
);
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
std
::
vector
<
PaddleTensor
>*
output_data
,
int
batch_size
=
-
1
)
override
{
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
void
OptimizeInferenceProgram
();
Argument
&
analysis_argument
()
{
return
argument_
;
}
private:
NativeConfig
config_
;
Argument
argument_
;
};
}
// namespace paddle
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
af15f6f0
...
@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
...
@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
for
(
auto
*
op
:
inference_program_
->
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
inference_program_
->
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"feed"
)
{
if
(
op
->
Type
()
==
"feed"
)
{
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
if
(
feeds_
.
size
()
<=
idx
)
{
if
(
feeds_
.
size
()
<=
static_cast
<
size_t
>
(
idx
)
)
{
feeds_
.
resize
(
idx
+
1
);
feeds_
.
resize
(
idx
+
1
);
}
}
feeds_
[
idx
]
=
op
;
feeds_
[
idx
]
=
op
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录