Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit 1ea9971a (unverified)

modify graph_pattern to thread_local (#43945)

Authored by JingZhuangzhuang on Jun 30, 2022; committed via GitHub on Jun 30, 2022.
Parent: 26187c27

Showing 3 changed files with 268 additions and 129 deletions (+268, -129):
paddle/fluid/framework/ir/graph_pattern_detector.cc  +103 -43
paddle/fluid/framework/ir/graph_pattern_detector.h   +64  -34
paddle/fluid/inference/api/analysis_predictor.cc     +101 -52
paddle/fluid/framework/ir/graph_pattern_detector.cc @ 1ea9971a

@@ -28,10 +28,17 @@ using string::Style;
 size_t PDPattern::id_ = 0UL;

+#ifdef PADDLE_WITH_TENSORRT
+namespace patterns {
+thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;
+}
+#endif
+
 PDNode *PDPattern::NewNode(const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
         node_map_.count(name),
         0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }
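The added block above is the out-of-class definition that a static `thread_local` data member requires; the matching in-class declaration is added in graph_pattern_detector.h below. A minimal standalone sketch of the declaration/definition pair (hypothetical `Counter` type, not Paddle code):

#include <string>
#include <unordered_map>

struct Counter {
  // Each thread that calls Inc() increments its own copy of dic_.
  static int Inc(const std::string &key) { return dic_[key]++; }
  static thread_local std::unordered_map<std::string, int> dic_;  // declaration
};

// Definition at namespace scope, exactly once per program; this is the
// analogue of the `thread_local ... KeyCounter::dic_;` line added above.
thread_local std::unordered_map<std::string, int> Counter::dic_;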
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
 PDNode *PDPattern::NewNode(PDNode::teller_t &&teller,
                            const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
         node_map_.count(name),
         0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }

@@ -70,7 +78,9 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
       a, platform::errors::NotFound("PDNode %s is not found.", a->name()));
   PADDLE_ENFORCE_NOT_NULL(
       b, platform::errors::NotFound("PDNode %s is not found.", b->name()));
   PADDLE_ENFORCE_NE(
       a,
       b,
       platform::errors::PermissionDenied(
           "Cannot connect the same node in the graph."));
   edges_.emplace_back(a, b);
 }

@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
   subgraphs->erase(
       std::remove_if(
           subgraphs->begin(),
           subgraphs->end(),
           [](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
             // Collect the inputs and outputs.
             std::set<Node *> ios;

@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
   }
   std::sort(
       subgraphs->begin(),
       subgraphs->end(),
       [](const GraphPatternDetector::subgraph_t &a,
          const GraphPatternDetector::subgraph_t &b) {
         for (auto &item : a) {

@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
 }
 PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
                                        const std::string &argument,
                                        int nth) {
   assert_is_var();
   assert_is_op_input(op_type);
   asserts_.emplace_back([=](Node *x) {

@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
 }
 PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type,
                                         const std::string &argument,
                                         int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {

@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
 PDNode *PDNode::assert_is_ops_nth_input(
     const std::unordered_set<std::string> &op_types,
     const std::string &argument,
     int nth) {
   assert_is_var();
   assert_is_ops_input(op_types);
   asserts_.emplace_back([=](Node *x) {

@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
 PDNode *PDNode::assert_is_ops_nth_output(
     const std::unordered_set<std::string> &op_types,
     const std::string &argument,
     int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {

@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
 bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
       var->IsVar(),
       true,
       platform::errors::InvalidArgument(
           "First parameter of function IsNthInput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
       op->IsOp(),
       true,
       platform::errors::InvalidArgument(
           "Second parameter of function IsNthInput must be Node::Op"));
   if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)

@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
 bool HasInput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
       op->IsOp(),
       true,
       platform::errors::InvalidArgument(
           "First parameter of function HasInput must be Node::Op"));
   auto const &names = op->Op()->InputNames();

@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
 bool HasOutput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
       op->IsOp(),
       true,
       platform::errors::InvalidArgument(
           "First parameter of function HasOuput must be Node::Op"));
   auto const &names = op->Op()->OutputNames();

@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
 bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
       var->IsVar(),
       true,
       platform::errors::InvalidArgument(
           "First parameter of function IsNthOutput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
       op->IsOp(),
       true,
       platform::errors::InvalidArgument(
           "Second parameter of function IsNthOutput must be Node::Op"));
   if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)

@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
     eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
         .LinksTo({eltwise_out_var});
     batch_norm_op
         ->LinksFrom({eltwise_out_var,
                      bn_scale_var,
                      bn_bias_var,
                      bn_mean_var,
                      bn_variance_var})
         .LinksTo({bn_out_var,
                   bn_mean_out_var,
                   bn_variance_out_var,
                   bn_saved_mean_var,
                   bn_saved_variance_var});
   } else {
     batch_norm_op
         ->LinksFrom({conv_out_var,
                      bn_scale_var,
                      bn_bias_var,
                      bn_mean_var,
                      bn_variance_var})
         .LinksTo({bn_out_var,
                   bn_mean_out_var,
                   bn_variance_out_var,
                   bn_saved_mean_var,
                   bn_saved_variance_var});
   }
   return bn_out_var;
 }

 PDNode *patterns::ConvActivation::operator()(
     paddle::framework::ir::PDNode *conv_input,
     std::string conv_type,
     std::string activation_type) {
   // Create Operators
   conv_input->assert_is_op_input(conv_type, "Input");

@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
 PDNode *patterns::ElementwiseActivation::operator()(
     paddle::framework::ir::PDNode *elementwise_a,
     const std::string &elementwise_type,
     const std::string &activation_type) {
   // Create Operators
   elementwise_a->assert_is_op_input(elementwise_type, "X");
   auto *elementwise_op =

@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
 }
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
                                  bool with_bias,
                                  bool with_relu) {
   // Create shared nodes.
   x->assert_is_op_input("mul", "X");
   auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");

@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
   bn->LinksFrom(
         {bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
       .LinksTo({bn_mean_out_var,
                 bn_variance_out_var,
                 bn_saved_variance_var,
                 bn_saved_mean_var,
                 bn_reserve_space,
                 bn_out_var});
   act->LinksFrom({bn_out_var}).LinksTo({act_out_var});
   return act_out_var;

@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
       .LinksTo({d_intermediate_var});
   bn_grad
       ->LinksFrom({bn_x_var,
                    d_intermediate_var,
                    bn_scale_var,
                    bn_bias_var,
                    bn_saved_mean_var,
                    bn_saved_variance_var,
                    bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
   return bn_grad;

@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
       pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
   bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
       .LinksTo({bn_mean_out_var,
                 bn_variance_out_var,
                 bn_saved_variance_var,
                 bn_saved_mean_var,
                 bn_reserve_space,
                 bn_out_var});
   elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
       .LinksTo({elewise_add_out_var});
   act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});

@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
       .LinksTo({d_elewise_add_in_var, d_bn_out_var});
   bn_grad
       ->LinksFrom({bn_x_var,
                    d_bn_out_var,
                    bn_scale_var,
                    bn_bias_var,
                    bn_saved_mean_var,
                    bn_saved_variance_var,
                    bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
   return bn_grad;

@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
 PDNode *patterns::LinearAct::operator()(
     paddle::framework::ir::PDNode *linear_x_var,
     const std::unordered_set<std::string> &act_types,
     bool with_grad_link,
     bool is_act_grad_x_from_act) {
   auto *matmul_w_var =
       pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y");

@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
 PDNode *patterns::ElewiseAddMatmulAct::operator()(
     paddle::framework::ir::PDNode *dout_var,
     const std::unordered_set<std::string> &act_grad_types,
     bool without_x_gradient,
     bool is_act_grad_x_from_act) {
   auto *ele_grad_bias_var =
       pattern->NewNode(ele_grad_bias_repr())
           ->assert_is_op_input("elementwise_add_grad", "Y");

@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
   return output_var;
 }
 PDNode *patterns::Elementwise::operator()(PDNode *x_var,
                                           PDNode *y_var,
                                           const std::string elementwise_type) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);

@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
 }
 PDNode *patterns::ResidualElementwise::operator()(
     PDNode *op_var,
     PDNode *residual_var,
     const std::string elementwise_type,
     bool as_x) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);

@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
 }
 PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
     const std::string &op_name,
     bool with_reshape_xshape,
     bool with_transpose_xshape) {
   auto reshape_op =
       pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");

@@ -3098,8 +3159,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
   transpose_out->assert_is_only_output_of_op("transpose2");
   auto transpose_xshape =
       with_transpose_xshape
           ? pattern->NewNode(transpose_xshape_repr())
                 ->AsIntermediate()
                 ->assert_is_op_output("transpose2", "XShape")
           : nullptr;
paddle/fluid/framework/ir/graph_pattern_detector.h @ 1ea9971a

@@ -122,10 +122,12 @@ struct PDNode {
   PDNode* assert_is_op_input(const std::string& op_type,
                              const std::string& argument);
   PDNode* assert_is_op_nth_input(const std::string& op_type,
                                  const std::string& argument,
                                  int nth);
   PDNode* assert_is_not_op_input(const std::string& argument);
   PDNode* assert_is_op_nth_output(const std::string& op_type,
                                   const std::string& argument,
                                   int nth);
   PDNode* assert_is_only_input_of_op(const std::string& op_type);
   PDNode* assert_is_only_output_of_op(const std::string& op_type);
   PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);

@@ -138,13 +140,15 @@ struct PDNode {
       const std::string& argument);
   PDNode* assert_is_ops_nth_input(
       const std::unordered_set<std::string>& op_types,
       const std::string& argument,
       int nth);
   PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types);
   PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types,
                               const std::string& argument);
   PDNode* assert_is_ops_nth_output(
       const std::unordered_set<std::string>& op_types,
       const std::string& argument,
       int nth);
   PDNode* assert_is_only_input_of_ops(
       const std::unordered_set<std::string>& op_types);

@@ -164,10 +168,13 @@ struct PDNode {
   }

  private:
   PDNode(PDPattern* pattern,
          const std::string& name = "",
          Type type = Type::kVar)
       : pattern_(pattern), name_(name), type_(type) {}
   PDNode(teller_t&& teller,
          PDPattern* pattern,
          const std::string& name = "",
          Type type = Type::kVar)
       : teller_(std::move(teller)),
         pattern_(pattern),

@@ -398,16 +405,25 @@ struct KeyCounter {
     return x;
   }

+#ifdef PADDLE_WITH_TENSORRT
+  static int IncCounter(const std::string& key) { return dic_[key]++; }
+  static void CleanCounter() { dic_.clear(); }
+
+ private:
+  static thread_local std::unordered_map<std::string, size_t> dic_;
+#else
   int IncCounter(const std::string& key) { return dic_[key]++; }

  private:
   std::unordered_map<std::string, size_t> dic_;
+#endif
 };

 // Generate a unique PDNode's name with name_scope and id.
 // The format is {name_scope}/{repr}/{id}/{name}
 static std::string PDNodeName(const std::string& name_scope,
                               const std::string& repr,
                               size_t id,
                               const std::string& name) {
   return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
 }
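With the `#ifdef PADDLE_WITH_TENSORRT` branch above, each thread keeps its own `dic_`, so two threads generating pattern keys concurrently neither race on the map nor interleave counter values, and `CleanCounter()` only resets the calling thread's counters. A small self-contained sketch of that behavior (hypothetical `KeyCounterSketch`, not the Paddle class):

#include <iostream>
#include <string>
#include <thread>
#include <unordered_map>

struct KeyCounterSketch {
  static int IncCounter(const std::string& key) { return dic_[key]++; }
  static void CleanCounter() { dic_.clear(); }
  static thread_local std::unordered_map<std::string, int> dic_;
};
thread_local std::unordered_map<std::string, int> KeyCounterSketch::dic_;

int main() {
  auto worker = [] {
    // Every thread prints 0, 1, 2: the counter is per-thread, so there is
    // no data race and no cross-thread numbering.
    for (int i = 0; i < 3; ++i)
      std::cout << KeyCounterSketch::IncCounter("conv_bn") << "\n";
  };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
  return 0;
}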
@@ -415,15 +431,15 @@ static std::string PDNodeName(const std::string& name_scope,
 // The format is {name_scope}/{repr}/{id}
 static std::string PDNodeName(const std::string& name_scope,
                               const std::string& repr) {
   return string::Sprintf(
       "%s/%s/%d", name_scope, repr, KeyCounter::Instance().IncCounter(repr));
 }

 // Generate a unique key. It can be used for a universally unique temporary
 // name.
 // The format is {repr}/{id}
 static std::string UniqueKey(const std::string& repr) {
   return string::Sprintf(
       "%s/%d", repr, KeyCounter::Instance().IncCounter(repr));
 }
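Since `PDNodeName` and `UniqueKey` both draw ids from the same per-key counter, the names they emit depend only on the call order within one thread once the counter is thread-local. Illustrative values for successive calls on a fresh thread (not output from a real run):

// PDNodeName("conv_bn_fuse", "conv_bn")  ->  "conv_bn_fuse/conv_bn/0"
// PDNodeName("conv_bn_fuse", "conv_bn")  ->  "conv_bn_fuse/conv_bn/1"
// UniqueKey("conv_bn")                   ->  "conv_bn/2"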
 // Declare a PDNode in a pattern, will create two methods:

@@ -440,17 +456,19 @@ static std::string UniqueKey(const std::string& repr) {
 // arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
 // pat: the pattern object.
 #define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat)                               \
   PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()),                             \
                     0UL,                                                       \
                     platform::errors::NotFound("Node not found for PDNode %s", \
                                                pat.arg##_repr()));             \
   Node* var = subgraph.at(pat.arg##_n());                                      \
   PADDLE_ENFORCE_NOT_NULL(                                                     \
       var,                                                                     \
       platform::errors::NotFound(                                              \
           "node %s not exists in the sub-graph", #arg));

 // The base class of all the patterns.
 struct PatternBase {
   PatternBase(PDPattern* pattern,
               const std::string& name_scope,
               const std::string& repr)
       : pattern(pattern),
         name_scope_(name_scope),

@@ -476,7 +494,8 @@ struct ConvBN : public PatternBase {
   ConvBN(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bn") {}

   PDNode* operator()(PDNode* conv_input,
                      const std::string& conv_type,
                      bool with_eltwise_add);

   // declare operator node's name

@@ -514,7 +533,8 @@ struct ConvActivation : public PatternBase {
   ConvActivation(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_activation") {}

   PDNode* operator()(PDNode* conv_input,
                      std::string conv_type = "conv2d",
                      std::string activation_type = "relu");

   // declare operator node's name

@@ -536,7 +556,8 @@ struct ElementwiseActivation : public PatternBase {
   ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "elementwise_add_activation") {}

   PDNode* operator()(PDNode* elementwise_a,
                      const std::string& elementwise_type,
                      const std::string& activation_type);

   // declare operator node's name

@@ -936,7 +957,8 @@ struct LinearAct : public PatternBase {
   PDNode* operator()(PDNode* x,
                      const std::unordered_set<std::string>& act_types,
                      bool with_grad_link,
                      bool is_act_grad_x_from_act);

   // declare operator node's name
   PATTERN_DECL_NODE(matmul);

@@ -965,7 +987,8 @@ struct ElewiseAddMatmulAct : public PatternBase {
   PDNode* operator()(PDNode* x,
                      const std::unordered_set<std::string>& act_grad_types,
                      bool without_x_gradient,
                      bool is_act_grad_x_from_act);

   // declare operator node's name
   PATTERN_DECL_NODE(ele_add_grad);

@@ -1062,7 +1085,8 @@ struct Elementwise : public PatternBase {
   Elementwise(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "elementwise") {}

   PDNode* operator()(PDNode* x_var,
                      PDNode* y_var,
                      const std::string elementwise_type);

   PATTERN_DECL_NODE(elementwise_op);

@@ -1088,11 +1112,14 @@ struct ElementwiseOp : public PatternBase {
 // This pattern allows operator output to be X or Y
 // and residual data Y or X, based on as_x flag
 struct ResidualElementwise : public PatternBase {
   ResidualElementwise(PDPattern* pattern,
                       const std::string& name_scope,
                       bool as_x)
       : PatternBase(pattern, name_scope, "residual_elementwise") {}

   PDNode* operator()(PDNode* op_var,
                      PDNode* residual_var,
                      const std::string elementwise_type,
                      bool as_x);

   PATTERN_DECL_NODE(operator_output);
   PATTERN_DECL_NODE(residual_data);

@@ -1467,8 +1494,8 @@ struct ConvElementwiseaddAct : public PatternBase {
 // Conv + ElementwiseAdd + ElementwiseAdd + Activation
 struct ConvElementwiseadd2Act : public PatternBase {
   ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(
             pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {}

   PDNode* operator()(PDNode* conv_in);

@@ -1702,7 +1729,8 @@ struct DequantOpFuse : public PatternBase {
   DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "dequant_fuse") {}

   void operator()(PDNode* quant_op_input,
                   const std::string& quantized_op_type,
                   const std::string& dequant_type,
                   const std::string& weight_name);

@@ -1758,8 +1786,8 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
 struct DeleteQuantDequantFilterOpPattern : public PatternBase {
   DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
                                     const std::string& name_scope)
       : PatternBase(
             pattern, name_scope, "delete_quantdequant_filter_op_pattern") {}

   void operator()();

@@ -1773,7 +1801,8 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
 struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
   DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
                                           const std::string& name_scope)
       : PatternBase(pattern,
                     name_scope,
                     "delete_weight_quant_dequant_linear_op_pattern") {}

   void operator()();

@@ -1788,8 +1817,8 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
 struct DeleteQuantDequantLinearOpPattern : public PatternBase {
   DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
                                     const std::string& name_scope)
       : PatternBase(
             pattern, name_scope, "delete_quant_dequant_linear_op_pattern") {}

   void operator()();

@@ -1814,7 +1843,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
                                 const std::string& name_scope)
       : PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}

   PDNode* operator()(const std::string& op_name,
                      bool with_reshape_xshape,
                      bool with_transpose_xshape);

   PATTERN_DECL_NODE(reshape_in);
paddle/fluid/inference/api/analysis_predictor.cc @ 1ea9971a

@@ -82,9 +82,9 @@ namespace paddle {
 using inference::Singleton;
 #if PADDLE_WITH_TENSORRT
-using inference::tensorrt::TRTInt8Calibrator;
 using inference::tensorrt::TRTCalibratorEngine;
 using inference::tensorrt::TRTCalibratorEngineManager;
+using inference::tensorrt::TRTInt8Calibrator;
 #endif
 int AnalysisPredictor::clone_num_ = 1;

@@ -101,7 +101,8 @@ bool IsPersistable(const framework::VarDesc *var) {
 }
 }  // namespace

 bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
                              framework::LoDTensor *t,
                              const platform::Place &place) {
   framework::DDim ddim = phi::make_ddim(pt.shape);
   void *input_ptr;

@@ -129,18 +130,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
   if (platform::is_cpu_place(place)) {
     // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
     std::memcpy(
         static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
   } else if (platform::is_ipu_place(place)) {
 #ifdef PADDLE_WITH_IPU
     std::memcpy(
         static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with WITH_IPU, should not reach here."));
 #endif
   } else if (platform::is_gpu_place(place)) {
     PADDLE_ENFORCE_EQ(
         platform::is_xpu_place(place),
         false,
         platform::errors::InvalidArgument(
             "Only one choice can be made between CPU and XPU."));
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

@@ -148,8 +150,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
     auto *dev_ctx =
         static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
     auto dst_gpu_place = place;
     memory::Copy(dst_gpu_place,
                  static_cast<void *>(input_ptr),
                  platform::CPUPlace(),
                  pt.data.data(),
                  pt.data.length(),
                  dev_ctx->stream());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(

@@ -158,8 +163,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
   } else if (platform::is_xpu_place(place)) {
 #ifdef PADDLE_WITH_XPU
     auto dst_xpu_place = place;
     memory::Copy(dst_xpu_place,
                  static_cast<void *>(input_ptr),
                  platform::CPUPlace(),
                  pt.data.data(),
                  pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with XPU, should not reach here."));

@@ -263,7 +271,8 @@ bool AnalysisPredictor::PrepareProgram(
 }
 bool AnalysisPredictor::CreateExecutor() {
   if (config_.use_gpu()) {
     PADDLE_ENFORCE_EQ(config_.use_xpu(),
                       false,
                       platform::errors::InvalidArgument(
                           "Only one choice can be made between CPU and XPU."));
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());

@@ -357,7 +366,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
 }
 static void DisablePrepareDataOpt(
     std::shared_ptr<framework::ProgramDesc> inference_program,
     int block,
     bool pre_disable_opt) {
   bool disable_opt = false;
   auto &infer_block = inference_program->Block(block);

@@ -367,8 +377,8 @@ static void DisablePrepareDataOpt(
   }
   if (op->HasAttr("sub_block")) {
     int blockID = op->GetBlockAttrId("sub_block");
     DisablePrepareDataOpt(
         inference_program, blockID, disable_opt || pre_disable_opt);
   }
   // disable prepare data if unfriendly op is found
   if (!disable_opt) {

@@ -386,8 +396,8 @@ bool AnalysisPredictor::PrepareExecutor() {
 #endif
   DisablePrepareDataOpt(inference_program_, 0, false);

   executor_->Prepare(
       sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);

   PADDLE_ENFORCE_NOT_NULL(sub_scope_,
                           platform::errors::PreconditionNotMet(

@@ -433,8 +443,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
     feed_fetch_vars.emplace_back(pair.second);
   }
   fleet_exe_->Init(config_.dist_config().carrier_id(),
                    *(inference_program_.get()),
                    scope_.get(),
                    place_,
                    1,
                    {task_node_.get()},
                    id_to_rank,
                    feed_fetch_vars);
   return true;
 }

@@ -471,8 +486,12 @@ bool AnalysisPredictor::CommInit() {
       peer_endpoints.emplace_back(
           config_.dist_config().trainer_endpoints()[rank]);
     }
     InsertCommOp(var_name_base + std::to_string(order),
                  ranks_in_group,
                  rank_in_group,
                  peer_endpoints,
                  comm_init_block,
                  ring_id);
     order += 1;
   }
   framework::NaiveExecutor e(place_);

@@ -484,8 +503,11 @@ bool AnalysisPredictor::CommInit() {
 }
 void AnalysisPredictor::InsertCommOp(
     std::string tmp_var_name,
     int nranks,
     int rank,
     const std::vector<std::string> &peer_endpoints,
     framework::BlockDesc *block,
     int ring_id) {
   /*
    * tmp_var_name: the var name for var comm_id

@@ -542,7 +564,8 @@ bool AnalysisPredictor::LoadConverterConfig(
             << config_.dist_config().comm_init_config() << "\n";
   std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
   PADDLE_ENFORCE_EQ(
       static_cast<bool>(fin.is_open()),
       true,
       platform::errors::NotFound(
           "Cannot open file %s, please confirm whether the file is normal.",
           config_.dist_config().comm_init_config()));

@@ -686,8 +709,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
   timer.tic();
   // set feed variable
   framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
   PADDLE_ENFORCE_NOT_NULL(
       scope,
       platform::errors::PreconditionNotMet("The scope should not be nullptr."));
   if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;

@@ -790,9 +814,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
   for (size_t i = 0; i < fetches_.size(); ++i) {
     int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
     PADDLE_ENFORCE_EQ(
         static_cast<size_t>(idx),
         i,
         platform::errors::InvalidArgument(
             "Fetch op's col attr(%d) should be equal to the index(%d)",
             idx,
             i));
     framework::FetchType &fetch_var =
         framework::GetFetchVariable(*scope, "fetch", idx);

@@ -833,7 +859,8 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.model_dir().empty()) {
     argument_.SetModelDir(config_.model_dir());
   } else {
     PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
                       false,
                       platform::errors::PreconditionNotMet(
                           "Either model_dir or prog_file should be set."));
     std::string dir = inference::analysis::GetDirRoot(config_.prog_file());

@@ -969,7 +996,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   Analyzer().Run(&argument_);
   PADDLE_ENFORCE_EQ(
       argument_.scope_valid(),
       true,
       platform::errors::InvalidArgument("The argument scope should be valid."));
   VLOG(5) << "to prepare executor";
   ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);

@@ -1008,8 +1036,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
 }
 template <>
 std::unique_ptr<PaddlePredictor>
 CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
     const AnalysisConfig &config) {
   // TODO(NHZlX): Should add the link to the doc of
   // paddle_infer::CreatePredictor<paddle_infer::Config>
   if (config.glog_info_disabled()) {

@@ -1018,7 +1047,8 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   }
   VLOG(3) << "create AnalysisConfig";
   PADDLE_ENFORCE_EQ(
       config.is_valid(),
       true,
       platform::errors::InvalidArgument(
           "Note: Each config can only be used for one predictor."));

@@ -1035,11 +1065,13 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   std::call_once(gflags_initialized, [&]() {
     std::vector<std::string> gflags;
     PADDLE_ENFORCE_GE(
         config.memory_pool_init_size_mb(),
         0.f,
         platform::errors::InvalidArgument(
             "The size of memory pool should be greater than 0."));
     PADDLE_ENFORCE_GE(
         config.gpu_device_id(),
         0,
         platform::errors::InvalidArgument(
             "Invalid device id (%d). The device id should be greater than 0.",
             config.gpu_device_id()));

@@ -1105,6 +1137,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   config.SetInValid();
   auto predictor_p = dynamic_cast<AnalysisPredictor *>(predictor.get());

+#ifdef PADDLE_WITH_TENSORRT
+  paddle::framework::ir::patterns::KeyCounter::Instance().CleanCounter();
+#endif
+
   if (!predictor_p->Init(nullptr)) {
     return nullptr;
   }
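The `CleanCounter()` call added above resets the calling thread's pattern-name counters right before the predictor is initialized, so IR passes that generate PDNode names start numbering from zero on every thread. A hedged usage sketch of the scenario this enables, one predictor created per thread (the `paddle_infer` names are the public inference API, but the config setup and header path here are illustrative only):

#include <string>
#include <thread>
#include <vector>

#include "paddle_inference_api.h"  // header name is an assumption

void CreatePredictorsConcurrently(const std::vector<std::string> &model_dirs) {
  std::vector<std::thread> workers;
  for (const auto &dir : model_dirs) {
    workers.emplace_back([dir] {
      paddle_infer::Config config;
      config.SetModel(dir);  // one Config per predictor, as the check above enforces
      // With the thread_local KeyCounter and the CleanCounter() reset, each
      // creation numbers its generated pattern nodes independently.
      auto predictor = paddle_infer::CreatePredictor(config);
      // ... run inference with `predictor` ...
    });
  }
  for (auto &w : workers) w.join();
}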
@@ -1154,8 +1190,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
 }
 void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
   PADDLE_ENFORCE_NOT_NULL(
       scope,
       platform::errors::InvalidArgument("The scope should not be nullptr."));
   auto *var = scope->Var("feed");
   var->GetMutable<framework::FeedList>();
   var = scope->Var("fetch");

@@ -1176,8 +1213,9 @@ AnalysisPredictor::GetInputTensorShape() {
   std::vector<std::string> names = GetInputNames();
   for (std::string name : names) {
     auto *var = inference_program_->Block(0).FindVar(name);
     PADDLE_ENFORCE_NOT_NULL(
         var,
         platform::errors::PreconditionNotMet("Input %s does not exist.", name));
     input_shapes[name] = var->GetShape();
   }
   return input_shapes;

@@ -1398,7 +1436,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     std::vector<std::pair<int32_t, int32_t>> counter;
     for (auto &it : m) counter.push_back(it);
     std::sort(counter.begin(),
               counter.end(),
               [](std::pair<int32_t, int32_t> &a,
                  std::pair<int32_t, int32_t> &b) {
                 return a.second > b.second;
               });

@@ -1420,8 +1459,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     opt_shapes[name] = opt_shape;
   }
   inference::SerializeShapeRangeInfo(
       config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes);
 }

 bool AnalysisPredictor::LoadProgramDesc() {

@@ -1441,7 +1480,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
     return false;
   }
   LOG(ERROR) << string::Sprintf(
       "not valid model path '%s' or program path '%s'.",
       config_.model_dir(),
       config_.params_file());
   return false;
 }

@@ -1453,7 +1493,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
   // Read binary
   std::ifstream fin(filename, std::ios::in | std::ios::binary);
   PADDLE_ENFORCE_EQ(
       static_cast<bool>(fin.is_open()),
       true,
       platform::errors::NotFound(
           "Cannot open file %s, please confirm whether the file is normal.",
           filename));

@@ -1555,7 +1596,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
 #if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
   PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
                     true,
                     platform::errors::PreconditionNotMet(
                         "This func can be invoked only in trt mode"));
   auto &block = inference_program_->Block(0);

@@ -1782,8 +1824,10 @@ Predictor::Predictor(const Config &config) {
             << "Paddle2ONNX do't support convert the Model, fall back to using "
                "Paddle Inference.";
   } else {
     predictor_ = paddle::CreatePaddlePredictor<
         Config, paddle::PaddleEngineKind::kONNXRuntime>(config);
     return;
   }
 #else

@@ -1793,8 +1837,10 @@ Predictor::Predictor(const Config &config) {
                "fall back to using Paddle Inference.";
 #endif
   }
   predictor_ = paddle::CreatePaddlePredictor<
       Config, paddle::PaddleEngineKind::kAnalysis>(config);
 }

 std::vector<std::string> Predictor::GetInputNames() {

@@ -1876,7 +1922,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) {  // NOLINT
 namespace services {
 PredictorPool::PredictorPool(const Config &config, size_t size) {
   PADDLE_ENFORCE_GE(
       size,
       1UL,
       paddle::platform::errors::InvalidArgument(
           "The predictor pool size should be greater than 1, but it's (%d)",
           size));

@@ -1895,9 +1942,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
 Predictor *PredictorPool::Retrive(size_t idx) {
   PADDLE_ENFORCE_LT(
       idx,
       preds_.size() + 1,
       paddle::platform::errors::InvalidArgument(
           "There are (%d) predictors in the pool, but the idx is (%d)",
           idx,
           preds_.size() + 1));
   if (idx == 0) {
     return main_pred_.get();