Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
af15f6f0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
af15f6f0
编写于
8月 31, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
8月 31, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fea/refine fuse (#13076)
上级
819af27d
变更
20
显示空白变更内容
内联
并排
Showing
20 changed file
with
545 addition
and
226 deletion
+545
-226
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+1
-1
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+37
-75
paddle/fluid/framework/ir/fc_fuse_pass.h
paddle/fluid/framework/ir/fc_fuse_pass.h
+2
-1
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+7
-2
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+4
-4
paddle/fluid/framework/ir/fuse_pass_base.h
paddle/fluid/framework/ir/fuse_pass_base.h
+18
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+180
-1
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+81
-15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
+33
-0
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+5
-5
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+0
-1
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+19
-0
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+2
-1
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+2
-2
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
+4
-3
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+18
-9
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+3
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+77
-102
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+51
-0
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+1
-1
未找到文件。
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
*
while_pat_node
=
gpd
.
pattern
().
RetriveNode
(
"while"
);
auto
*
while_pat_node
=
gpd
.
pattern
().
Retri
e
veNode
(
"while"
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
marked_nodes
.
insert
(
while_node
);
};
...
...
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
}
void
BuildFCPattern
(
PDPattern
*
pattern
)
{
// make sure the selected MUL op has one input argument is a parameter.
auto
*
mul_parameter_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
==
1UL
&&
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"mul"
&&
node
->
Var
()
&&
node
->
Var
()
->
Persistable
();
// check is a parameter
},
"mul_weight"
/*name*/
);
auto
*
mul_tmp_input_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
bool
result
=
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
!
node
->
Var
()
->
Persistable
();
// this input is not an parameter.
if
(
!
result
)
return
false
;
// check whether one output is MUL op.
for
(
auto
*
op
:
node
->
outputs
)
{
if
(
op
->
IsOp
()
&&
op
->
Op
()
->
Type
()
==
"mul"
)
return
true
;
}
return
false
;
},
"mul_tmp_var"
/*name*/
);
// select a MUL op
auto
*
mul_op
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
()
&&
// start from an Op
node
->
Op
()
->
Type
()
==
"mul"
;
// type is mul
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul"
/*name*/
);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto
*
mul_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
// starts from a Var
node
->
outputs
.
size
()
==
1UL
&&
// only has one consumer
node
->
outputs
.
front
()
->
IsOp
()
&&
// check basic logic
node
->
Var
()
&&
// not a ControlDepVar
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
// a very strong validation
},
"mul_out"
);
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
VarOutLinksToOp
(
node
,
"elementwise_add"
);
},
"elementwise_add_tmpvar"
);
// select an ELEMENTWISE_ADD op
auto
*
elementwise_add_op
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add"
/*name*/
);
// get the ELEMENTWISE_ADD op's output
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
inputs
.
size
()
==
1UL
&&
node
->
Var
()
&&
node
->
inputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add_out"
);
mul_op
->
LinksFrom
({
mul_parameter_var
,
mul_tmp_input_var
})
.
LinksTo
({
mul_out_var
});
// Create Operators
auto
*
mul_op
=
pattern
->
NewNode
(
"mul"
)
->
assert_is_op
(
"mul"
);
auto
*
elementwise_add_op
=
pattern
->
NewNode
(
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
);
// Create variables
// w
auto
*
mul_weight_var
=
pattern
->
NewNode
(
"mul_weight"
)
->
AsInput
()
->
assert_is_op_nth_input
(
"mul"
,
"Y"
,
0
);
// x
auto
*
mul_tmp_var
=
pattern
->
NewNode
(
"mul_tmp_var"
)
->
AsInput
()
->
assert_is_op_nth_input
(
"mul"
,
"X"
,
0
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
mul_out_var
=
pattern
->
NewNode
(
"mul_out"
)
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"mul"
)
->
assert_is_op_input
(
"elementwise_add"
);
// bias
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
"elementwise_add_tmpvar"
)
->
assert_is_op_input
(
"elementwise_add"
)
->
AsInput
();
// output
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
"elementwise_add_out"
)
->
AsOutput
()
->
assert_is_op_output
(
"elementwise_add"
);
mul_op
->
LinksFrom
({
mul_weight_var
,
mul_tmp_var
}).
LinksTo
({
mul_out_var
});
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
.
LinksTo
({
elementwise_add_out_var
});
}
...
...
@@ -120,6 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std
::
unique_ptr
<
ir
::
Graph
>
FCFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"fc"
,
graph
.
get
());
std
::
unordered_set
<
Node
*>
nodes2delete
;
...
...
@@ -127,11 +85,12 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
BuildFCPattern
(
gpd
.
mutable_pattern
());
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
auto* id = subgraph.at(gpd.pattern().Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int
found_fc_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle FC fuse"
;
...
...
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
graph
->
RemoveNode
(
mul
);
graph
->
RemoveNode
(
elementwise_add
);
graph
->
RemoveNode
(
mul_out
);
// tmp variable
found_fc_count
++
;
};
gpd
(
graph
.
get
(),
handler
);
AddStatis
(
found_fc_count
);
return
graph
;
}
...
...
paddle/fluid/framework/ir/fc_fuse_pass.h
浏览文件 @
af15f6f0
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
...
...
@@ -23,7 +24,7 @@ namespace ir {
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
class
FCFusePass
:
public
Pass
{
class
FCFusePass
:
public
FusePassBase
{
public:
virtual
~
FCFusePass
()
{}
...
...
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
浏览文件 @
af15f6f0
...
...
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
const
std
::
vector
<
std
::
string
>&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetInput
(
"Xs"
,
inputs
);
op
->
SetOutput
(
"Ys"
,
outputs
);
if
(
type
==
"mul"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetInput
(
"X"
,
inputs
);
}
op
->
SetOutput
(
"Out"
,
outputs
);
}
// a->OP0->b
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
RetriveNode
(
"any_node"
));
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
Retri
e
veNode
(
"any_node"
));
marked_nodes
.
insert
(
id
);
};
gpd
(
graph
.
get
(),
handler
);
...
...
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
#undef GET_NODE
#undef SET_IN
LOG
(
INFO
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
LOG
(
INFO
)
<<
"cell: "
<<
cell_n
->
Name
();
LOG
(
INFO
)
<<
"xx: "
<<
xx_n
->
Name
();
VLOG
(
4
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
VLOG
(
4
)
<<
"cell: "
<<
cell_n
->
Name
();
VLOG
(
4
)
<<
"xx: "
<<
xx_n
->
Name
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
...
...
paddle/fluid/framework/ir/fuse_pass_base.h
浏览文件 @
af15f6f0
...
...
@@ -22,21 +22,37 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
static
const
char
kParamScopeAttr
[]
=
"param_scope"
;
static
const
char
kParamScopeAttr
[]
=
"__param_scope__"
;
static
const
char
kFuseStatisAttr
[]
=
"__fuse_statis__"
;
class
FusePassBase
:
public
Pass
{
public:
void
Init
(
Graph
*
graph
)
const
{
graph_
=
graph
;
}
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
{
repr_
=
repr
;
graph_
=
graph
;
}
Scope
*
param_scope
()
const
{
PADDLE_ENFORCE
(
graph_
->
Has
(
kParamScopeAttr
));
return
graph_
->
Get
<
framework
::
Scope
*>
(
kParamScopeAttr
);
}
void
AddStatis
(
int
count_of_fused
)
const
{
PADDLE_ENFORCE
(
graph_
);
PADDLE_ENFORCE
(
!
repr_
.
empty
());
if
(
!
graph_
->
Has
(
kFuseStatisAttr
))
{
graph_
->
Set
(
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
);
}
auto
&
info
=
graph_
->
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
kFuseStatisAttr
);
info
[
repr_
]
=
count_of_fused
;
}
virtual
~
FusePassBase
()
{}
protected:
mutable
Graph
*
graph_
;
mutable
std
::
string
repr_
;
};
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
af15f6f0
...
...
@@ -27,6 +27,19 @@ namespace ir {
size_t
PDPattern
::
id_
=
0UL
;
PDNode
*
PDPattern
::
NewNode
(
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
"PDNode's name should be unique, get duplicate [%s]"
,
name
);
}
nodes_
.
emplace_back
(
new
PDNode
(
this
,
name
));
auto
*
cur
=
nodes_
.
back
().
get
();
node_map_
[
name
]
=
cur
;
return
cur
;
}
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
...
...
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
return
cur
;
}
PDNode
*
PDPattern
::
RetriveNode
(
const
std
::
string
&
id
)
const
{
PDNode
*
PDPattern
::
Retri
e
veNode
(
const
std
::
string
&
id
)
const
{
auto
it
=
node_map_
.
find
(
id
);
if
(
it
==
node_map_
.
end
())
{
return
nullptr
;
...
...
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
ValidateByNodeRole
(
&
subgraphs
);
if
(
subgraphs
.
empty
())
return
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
...
...
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
}
// Check to early stop if some PDNode can't find matched Node.
for
(
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
!
pdnodes2nodes_
.
count
(
pdnode
.
get
()))
{
VLOG
(
4
)
<<
pdnode
->
name
()
<<
" can't find matched Node, early stop"
;
return
false
;
}
}
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
return
!
pdnodes2nodes_
.
empty
();
}
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void
GraphPatternDetector
::
ValidateByNodeRole
(
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>
result
;
subgraphs
->
erase
(
std
::
remove_if
(
subgraphs
->
begin
(),
subgraphs
->
end
(),
[](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
bool
{
// Collect the inputs and outputs.
std
::
unordered_set
<
Node
*>
ios
;
for
(
auto
&
item
:
subgraph
)
{
if
(
!
item
.
first
->
IsIntermediate
())
{
ios
.
insert
(
item
.
second
);
}
}
for
(
auto
&
item
:
subgraph
)
{
if
(
item
.
first
->
IsIntermediate
())
{
for
(
auto
*
x
:
item
.
second
->
inputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
for
(
auto
*
x
:
item
.
second
->
outputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
}
}
return
false
;
}),
subgraphs
->
end
());
}
struct
HitGroup
{
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
...
...
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
// in edges of PDNodes.
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
VLOG
(
4
)
<<
"check "
<<
edge
.
first
->
name
()
<<
" -> "
<<
edge
.
second
->
name
();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
auto
&
pre_groups
=
bi_records
[
step
%
2
];
...
...
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
// source -> target
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
VLOG
(
8
)
<<
"check "
<<
source
->
id
()
<<
" -- "
<<
target
->
id
();
// TODO(Superjomn) add some prune strategies.
for
(
const
auto
&
group
:
pre_groups
)
{
HitGroup
new_group
=
group
;
...
...
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
}
}
VLOG
(
3
)
<<
"step "
<<
step
<<
" get records: "
<<
cur_groups
.
size
();
for
(
auto
&
group
:
cur_groups
)
{
for
(
auto
&
item
:
group
.
roles
)
{
VLOG
(
4
)
<<
"node "
<<
item
.
second
->
id
()
<<
" as "
<<
item
.
first
->
name
();
}
VLOG
(
4
)
<<
"========================================================="
;
}
}
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
...
...
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
return
*
this
;
}
PDNode
*
PDNode
::
assert_is_op
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op
(
const
std
::
string
&
op_type
)
{
asserts_
.
emplace_back
([
this
,
op_type
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
op_type
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_var
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_var_not_persistable
()
{
assert_is_var
();
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
!
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_persistable_var
()
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
IsNthInput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
IsNthOutput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
inputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
outputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_output
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
inputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
outputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_more
(
PDNode
::
teller_t
&&
teller
)
{
asserts_
.
emplace_back
(
std
::
move
(
teller
));
return
this
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
af15f6f0
...
...
@@ -39,14 +39,24 @@ struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode.
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
enum
class
Type
{
kOp
,
kVar
};
enum
class
Role
{
kUnknown
,
// No role,
kInput
,
// an input and will be retained,
kOutput
,
// an output and will be retained,
kIntermediate
// will be removed after handler.
};
// this link to others
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
bool
Tell
(
Node
*
node
)
const
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
return
teller_
(
node
);
if
(
teller_
)
return
teller_
(
node
);
for
(
auto
&
asrt
:
asserts_
)
{
if
(
!
asrt
(
node
))
return
false
;
}
return
true
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
...
...
@@ -54,10 +64,52 @@ struct PDNode {
const
std
::
string
&
name
()
const
{
return
name_
;
}
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
(
const
PDNode
&
)
=
delete
;
// Mark this node is an Input of a subgraph and will be retained.
PDNode
*
AsInput
()
{
role_
=
Role
::
kInput
;
return
this
;
}
// Mark this node is an Output of a subgraph and will be retained.
PDNode
*
AsOutput
()
{
role_
=
Role
::
kOutput
;
return
this
;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PDNode
*
AsIntermediate
()
{
role_
=
Role
::
kIntermediate
;
return
this
;
}
bool
IsIntermediate
()
const
{
return
role_
==
Role
::
kIntermediate
;
}
bool
IsInput
()
const
{
return
role_
==
Role
::
kInput
;
}
bool
IsOutput
()
const
{
return
role_
==
Role
::
kOutput
;
}
// Assertions, helper functions to simplify the pattern definition.
PDNode
*
assert_is_op
();
PDNode
*
assert_is_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_var
();
PDNode
*
assert_var_not_persistable
();
PDNode
*
assert_is_persistable_var
();
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_more
(
teller_t
&&
teller
);
private:
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
pattern_
(
pattern
),
name_
(
name
),
type_
(
type
)
{}
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
teller_
(
std
::
move
(
teller
)),
...
...
@@ -71,10 +123,13 @@ struct PDNode {
friend
class
PDPattern
;
// Will removed latter.
teller_t
teller_
;
std
::
vector
<
teller_t
>
asserts_
;
PDPattern
*
pattern_
;
std
::
string
name_
;
Type
type_
;
Role
role_
{
Role
::
kUnknown
};
};
/*
...
...
@@ -87,19 +142,18 @@ struct PDNode {
* This pattern can be defined as with the following pseudo codes
*
* // Create two operator PDNodes.
* MUL = PDPattern.NewNode()
* ELE = PDPattern.NewNode()
* MUL = PDPattern.NewNode()
.assert_is_op("mul");
* ELE = PDPattern.NewNode()
.assert_is_op("elementwise_add");
* // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode()
* // Add teller to define some rules that help to filter the target Nodes.
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
* ELE.teller = lambda(node): \
* node->IsOp() && node->Op()->Type == "elementwise_add";
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
* && (ELE in node->outputs)
* MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
* .assert_is_op_input("elementwise_add") \
* .AsIntermediate();
* // Add relations.
* MUL->LinksTo({MUL_out});
* MUL_out->LinksTo({ELE});
*
* One can add more specific
teller
s for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.
teller
.
* One can add more specific
assert
s for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.
assert_more(...)
.
*
* PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
...
...
@@ -112,7 +166,8 @@ class PDPattern {
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetriveNode
(
const
std
::
string
&
id
)
const
;
PDNode
*
NewNode
(
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetrieveNode
(
const
std
::
string
&
id
)
const
;
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
...
...
@@ -185,6 +240,9 @@ class GraphPatternDetector {
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
// Validate whether the intermediate nodes are linked by external nodes.
void
ValidateByNodeRole
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
...
...
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
}
static
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
inputs
.
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Output
(
argument
)[
nth
];
}
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
浏览文件 @
af15f6f0
...
...
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
ASSERT_LE
(
count
,
2
);
}
TEST
(
GraphPatternDetector
,
IntermediateCheck
)
{
ProgramDesc
program
;
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
GraphPatternDetector
detector
;
auto
*
op2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op2"
;
},
"op2"
);
auto
*
op3
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op3"
;
},
"op3"
);
auto
*
v2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
Name
()
==
"var2"
;
},
"var2"
)
->
AsIntermediate
();
v2
->
LinksFrom
({
op2
}).
LinksTo
({
op3
});
int
count
=
0
;
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
EXPECT_EQ
(
count
,
0
);
count
=
0
;
v2
->
AsInput
();
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
ASSERT_EQ
(
count
,
1
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
std
::
unique_ptr
<
ir
::
Graph
>
SeqConcatFcFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
graph
.
get
());
FusePassBase
::
Init
(
"seq_concat_fc_fuse"
,
graph
.
get
());
GraphPatternDetector
detector
;
auto
*
pattern
=
detector
.
mutable_pattern
();
auto
*
concat_out
=
BuildSeqExpandConcatPattern
(
pattern
);
BuildFCPattern
(
pattern
,
concat_out
);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(pattern.Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
auto* id = subgraph.at(pattern.Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
af15f6f0
...
...
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
void
AddGraphvizDebugerPass
(
Pass
*
pass
)
{
auto
*
debuger_pass
=
pass
->
CreateGraphvizDebugerPass
();
if
(
debuger_pass
)
{
LOG
(
INFO
)
<<
" - register debug pass ["
<<
debuger_pass
->
repr
()
<<
"]"
;
Register
(
debuger_pass
->
repr
(),
debuger_pass
);
}
}
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
af15f6f0
...
...
@@ -16,10 +16,13 @@
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
...
...
@@ -31,6 +34,8 @@ namespace paddle {
namespace
inference
{
namespace
analysis
{
using
namespace
framework
;
TEST
(
Analyzer
,
analysis_without_tensorrt
)
{
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
Argument
argument
;
...
...
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
}
}
if
(
use_analysis
&&
activate_ir
)
{
AnalysisPredictor
*
analysis_predictor
=
dynamic_cast
<
AnalysisPredictor
*>
(
predictor
.
get
());
auto
&
fuse_statis
=
analysis_predictor
->
analysis_argument
()
.
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
framework
::
ir
::
kFuseStatisAttr
);
for
(
auto
&
item
:
fuse_statis
)
{
LOG
(
INFO
)
<<
"fused "
<<
item
.
first
<<
" "
<<
item
.
second
;
}
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc"
),
1
);
}
}
// Directly infer with the original model.
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
af15f6f0
...
...
@@ -64,7 +64,8 @@ struct Argument {
template
<
typename
T
>
void
Set
(
const
std
::
string
&
key
,
T
*
data
)
{
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"duplicate attr called %s"
,
key
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"Duplicate set Argument's attr [%s]"
,
key
);
attrs_
[
key
]
=
data
;
attr_deleters_
[
key
]
=
[
data
,
key
,
this
]()
{
VLOG
(
3
)
<<
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
;
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
...
...
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
bool
DataFlowGraphToFluidPass
::
Initialize
(
Argument
*
argument
)
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
)
PADDLE_ENFORCE
(
!
argument
->
transformed_program_desc
);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph.
...
...
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
framework
::
ir
::
kParamScopeAttr
))
{
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
}
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
浏览文件 @
af15f6f0
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
...
...
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
)
{
PADDLE_ENFORCE
(
argument_
);
argument_
->
Set
(
"param_scope"
,
new
framework
::
Scope
);
argument_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
);
// Load parameters.
VLOG
(
3
)
<<
"Loading parameters from "
<<
model_dir
;
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
),
model_dir
,
prog_file
,
param_file
);
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
)
,
model_dir
,
prog_file
,
param_file
);
}
bool
FluidToIrPass
::
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
浏览文件 @
af15f6f0
...
...
@@ -14,12 +14,14 @@
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
using
namespace
framework
;
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
...
...
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
// Load program.
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
argument
->
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
program
));
argument
->
origin_program_desc
.
reset
(
new
proto
::
ProgramDesc
(
program
));
// Create main data flow graph.
if
(
!
argument
->
main_dfg
)
{
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
}
argument
->
Set
(
"ir_program_desc"
,
new
framework
::
ProgramDesc
(
program
));
argument
->
Set
(
"ir_program_desc"
,
new
ProgramDesc
(
program
));
LOG
(
INFO
)
<<
"Loading parameters"
;
// Load parameters to argument if needed.
...
...
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
void
Run
(
DataFlowGraph
*
graph
)
override
{
// Call all the IR Passes
IRPassManager
ir_passes
(
argument_
->
Get
<
framework
::
ProgramDesc
>
(
"ir_program_desc"
),
nullptr
);
IRPassManager
ir_passes
(
argument_
->
Get
<
ProgramDesc
>
(
"ir_program_desc"
),
nullptr
);
// Pass the scope from analysis to IR if needed.
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
ir
::
kParamScopeAttr
))
{
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
ir_passes
.
graph
().
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
)));
ir
::
kParamScopeAttr
,
new
Scope
*
(
&
argument_
->
Get
<
Scope
>
(
ir
::
kParamScopeAttr
)));
}
const
auto
&
ir_passes_to_apply
=
...
...
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
// inherit the arguments from ir.
if
(
ir_passes
.
graph
().
Has
(
ir
::
kFuseStatisAttr
))
{
argument_
->
Set
(
ir
::
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
(
ir_passes
.
graph
().
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
ir
::
kFuseStatisAttr
)));
}
}
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
...
...
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private:
// Load parameters from a single file or from a directory.
bool
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
bool
LoadParams
(
Scope
*
scope
,
const
std
::
string
&
dir
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
private:
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
af15f6f0
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
framework
::
Scope
*
scope
)
:
program_
(
program
)
{
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
if
(
scope
)
graph_
->
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
scope
));
if
(
scope
)
graph_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
*
(
scope
));
}
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
af15f6f0
...
...
@@ -12,31 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
using
inference
::
Singleton
;
using
inference
::
analysis
::
Analyzer
;
using
framework
::
proto
::
ProgramDesc
;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
bool
AnalysisPredictor
::
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
...
...
@@ -58,8 +45,8 @@ class AnalysisPredictor : public NativePaddlePredictor {
if
(
!
config_
.
model_dir
.
empty
())
{
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
model_dir
);
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
model_dir
);
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
// All parameters are saved in a single file.
// The file names should be consistent with that used
...
...
@@ -78,55 +65,43 @@ class AnalysisPredictor : public NativePaddlePredictor {
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch
();
return
true
;
}
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
std
::
vector
<
PaddleTensor
>*
output_data
,
int
batch_size
=
-
1
)
override
{
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
}
void
OptimizeInferenceProgram
()
{
void
AnalysisPredictor
::
OptimizeInferenceProgram
()
{
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Analyze inference_program
Argument
argument
;
if
(
!
config_
.
model_dir
.
empty
())
{
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
argument_
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
{
PADDLE_ENFORCE
(
!
config_
.
param_file
.
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
argument
.
fluid_model_program_path
.
reset
(
argument_
.
fluid_model_program_path
.
reset
(
new
std
::
string
(
config_
.
prog_file
));
argument
.
fluid_model_param_path
.
reset
(
new
std
::
string
(
config_
.
param_file
));
argument_
.
fluid_model_param_path
.
reset
(
new
std
::
string
(
config_
.
param_file
));
}
argument
.
origin_program_desc
.
reset
(
argument_
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
CHECK
(
argument
.
transformed_program_desc
);
Analyzer
().
Run
(
&
argument_
);
CHECK
(
argument_
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument
.
Has
(
"param_scope"
));
new
framework
::
ProgramDesc
(
*
argument_
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument_
.
Has
(
framework
::
ir
::
kParamScopeAttr
));
// Update scope.
scope_
.
reset
(
argument
.
Release
<
framework
::
Scope
>
(
"param_scope"
));
scope_
.
reset
(
argument_
.
Release
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
));
LOG
(
INFO
)
<<
"optimize end =="
;
}
private:
NativeConfig
config_
;
};
}
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
...
...
paddle/fluid/inference/api/analysis_predictor.h
0 → 100644
浏览文件 @
af15f6f0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
using
inference
::
analysis
::
Analyzer
;
using
framework
::
proto
::
ProgramDesc
;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
);
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
std
::
vector
<
PaddleTensor
>*
output_data
,
int
batch_size
=
-
1
)
override
{
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
void
OptimizeInferenceProgram
();
Argument
&
analysis_argument
()
{
return
argument_
;
}
private:
NativeConfig
config_
;
Argument
argument_
;
};
}
// namespace paddle
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
af15f6f0
...
...
@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
for
(
auto
*
op
:
inference_program_
->
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"feed"
)
{
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
if
(
feeds_
.
size
()
<=
idx
)
{
if
(
feeds_
.
size
()
<=
static_cast
<
size_t
>
(
idx
)
)
{
feeds_
.
resize
(
idx
+
1
);
}
feeds_
[
idx
]
=
op
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录