Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1ea9971a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1ea9971a
编写于
6月 30, 2022
作者:
J
JingZhuangzhuang
提交者:
GitHub
6月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify graph_pattern to thread_local (#43945)
上级
26187c27
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
268 addition
and
129 deletion
+268
-129
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+103
-43
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+64
-34
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+101
-52
未找到文件。
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
1ea9971a
...
...
@@ -28,10 +28,17 @@ using string::Style;
size_t
PDPattern
::
id_
=
0UL
;
#ifdef PADDLE_WITH_TENSORRT
namespace
patterns
{
thread_local
std
::
unordered_map
<
std
::
string
,
size_t
>
KeyCounter
::
dic_
;
}
#endif
PDNode
*
PDPattern
::
NewNode
(
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0UL
,
node_map_
.
count
(
name
),
0UL
,
platform
::
errors
::
PreconditionNotMet
(
"PDNode's name should be unique, get duplicate [%s]"
,
name
));
}
...
...
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0UL
,
node_map_
.
count
(
name
),
0UL
,
platform
::
errors
::
PreconditionNotMet
(
"PDNode's name should be unique, get duplicate [%s]"
,
name
));
}
...
...
@@ -70,7 +78,9 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
a
,
platform
::
errors
::
NotFound
(
"PDNode %s is not found."
,
a
->
name
()));
PADDLE_ENFORCE_NOT_NULL
(
b
,
platform
::
errors
::
NotFound
(
"PDNode %s is not found."
,
b
->
name
()));
PADDLE_ENFORCE_NE
(
a
,
b
,
platform
::
errors
::
PermissionDenied
(
PADDLE_ENFORCE_NE
(
a
,
b
,
platform
::
errors
::
PermissionDenied
(
"Cannot connect the same node in the graph."
));
edges_
.
emplace_back
(
a
,
b
);
}
...
...
@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
subgraphs
->
erase
(
std
::
remove_if
(
subgraphs
->
begin
(),
subgraphs
->
end
(),
subgraphs
->
begin
(),
subgraphs
->
end
(),
[](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
bool
{
// Collect the inputs and outputs.
std
::
set
<
Node
*>
ios
;
...
...
@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
}
std
::
sort
(
subgraphs
->
begin
(),
subgraphs
->
end
(),
subgraphs
->
begin
(),
subgraphs
->
end
(),
[](
const
GraphPatternDetector
::
subgraph_t
&
a
,
const
GraphPatternDetector
::
subgraph_t
&
b
)
{
for
(
auto
&
item
:
a
)
{
...
...
@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
}
PDNode
*
PDNode
::
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
...
...
@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
}
PDNode
*
PDNode
::
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
...
...
@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
PDNode
*
PDNode
::
assert_is_ops_nth_input
(
const
std
::
unordered_set
<
std
::
string
>
&
op_types
,
const
std
::
string
&
argument
,
int
nth
)
{
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_ops_input
(
op_types
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
...
...
@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
PDNode
*
PDNode
::
assert_is_ops_nth_output
(
const
std
::
unordered_set
<
std
::
string
>
&
op_types
,
const
std
::
string
&
argument
,
int
nth
)
{
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
...
...
@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
bool
IsNthInput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE_EQ
(
var
->
IsVar
(),
true
,
var
->
IsVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"First parameter of function IsNthInput must be Node::Var"
));
PADDLE_ENFORCE_EQ
(
op
->
IsOp
(),
true
,
op
->
IsOp
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Second parameter of function IsNthInput must be Node::Op"
));
if
(
!
HasInput
(
op
,
argument
)
||
op
->
Op
()
->
Input
(
argument
).
size
()
<=
nth
)
...
...
@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
bool
HasInput
(
Node
*
op
,
const
std
::
string
&
argument
)
{
PADDLE_ENFORCE_EQ
(
op
->
IsOp
(),
true
,
op
->
IsOp
(),
true
,
platform
::
errors
::
InvalidArgument
(
"First parameter of function HasInput must be Node::Op"
));
auto
const
&
names
=
op
->
Op
()
->
InputNames
();
...
...
@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
bool
HasOutput
(
Node
*
op
,
const
std
::
string
&
argument
)
{
PADDLE_ENFORCE_EQ
(
op
->
IsOp
(),
true
,
op
->
IsOp
(),
true
,
platform
::
errors
::
InvalidArgument
(
"First parameter of function HasOuput must be Node::Op"
));
auto
const
&
names
=
op
->
Op
()
->
OutputNames
();
...
...
@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE_EQ
(
var
->
IsVar
(),
true
,
var
->
IsVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"First parameter of function IsNthOutput must be Node::Var"
));
PADDLE_ENFORCE_EQ
(
op
->
IsOp
(),
true
,
op
->
IsOp
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Second parameter of function IsNthOutput must be Node::Op"
));
if
(
!
HasOutput
(
op
,
argument
)
||
op
->
Op
()
->
Output
(
argument
).
size
()
<=
nth
)
...
...
@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
eltwise_op
->
LinksFrom
({
conv_out_var
,
eltwise_y_in_var
})
.
LinksTo
({
eltwise_out_var
});
batch_norm_op
->
LinksFrom
({
eltwise_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_mean_var
,
->
LinksFrom
({
eltwise_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_mean_var
,
bn_variance_var
})
.
LinksTo
({
bn_out_var
,
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_mean_var
,
bn_saved_variance_var
});
.
LinksTo
({
bn_out_var
,
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_mean_var
,
bn_saved_variance_var
});
}
else
{
batch_norm_op
->
LinksFrom
({
conv_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_mean_var
,
->
LinksFrom
({
conv_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_mean_var
,
bn_variance_var
})
.
LinksTo
({
bn_out_var
,
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_mean_var
,
bn_saved_variance_var
});
.
LinksTo
({
bn_out_var
,
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_mean_var
,
bn_saved_variance_var
});
}
return
bn_out_var
;
}
PDNode
*
patterns
::
ConvActivation
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
conv_input
,
std
::
string
conv_type
,
paddle
::
framework
::
ir
::
PDNode
*
conv_input
,
std
::
string
conv_type
,
std
::
string
activation_type
)
{
// Create Operators
conv_input
->
assert_is_op_input
(
conv_type
,
"Input"
);
...
...
@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
PDNode
*
patterns
::
ElementwiseActivation
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
elementwise_a
,
const
std
::
string
&
elementwise_type
,
const
std
::
string
&
activation_type
)
{
const
std
::
string
&
elementwise_type
,
const
std
::
string
&
activation_type
)
{
// Create Operators
elementwise_a
->
assert_is_op_input
(
elementwise_type
,
"X"
);
auto
*
elementwise_op
=
...
...
@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
}
PDNode
*
patterns
::
FC
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
x
,
bool
with_bias
,
bool
with_relu
)
{
bool
with_bias
,
bool
with_relu
)
{
// Create shared nodes.
x
->
assert_is_op_input
(
"mul"
,
"X"
);
auto
*
mul
=
pattern
->
NewNode
(
mul_repr
())
->
assert_is_op
(
"mul"
);
...
...
@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
bn
->
LinksFrom
(
{
bn_x_var
,
bn_scale_var
,
bn_bias_var
,
bn_variance_var
,
bn_mean_var
})
.
LinksTo
({
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_variance_var
,
bn_saved_mean_var
,
bn_reserve_space
,
bn_out_var
});
.
LinksTo
({
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_variance_var
,
bn_saved_mean_var
,
bn_reserve_space
,
bn_out_var
});
act
->
LinksFrom
({
bn_out_var
}).
LinksTo
({
act_out_var
});
return
act_out_var
;
...
...
@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
.
LinksTo
({
d_intermediate_var
});
bn_grad
->
LinksFrom
({
bn_x_var
,
d_intermediate_var
,
bn_scale_var
,
bn_bias_var
,
bn_saved_mean_var
,
bn_saved_variance_var
,
bn_reserve_space
})
->
LinksFrom
({
bn_x_var
,
d_intermediate_var
,
bn_scale_var
,
bn_bias_var
,
bn_saved_mean_var
,
bn_saved_variance_var
,
bn_reserve_space
})
.
LinksTo
({
d_bn_x_var
,
d_bn_scale_var
,
d_bn_bias_var
});
return
bn_grad
;
...
...
@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_ops_output
(
act_types
,
"Out"
);
bn
->
LinksFrom
({
bn_x_var
,
bn_scale_var
,
bn_bias_var
})
.
LinksTo
({
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_variance_var
,
bn_saved_mean_var
,
bn_reserve_space
,
bn_out_var
});
.
LinksTo
({
bn_mean_out_var
,
bn_variance_out_var
,
bn_saved_variance_var
,
bn_saved_mean_var
,
bn_reserve_space
,
bn_out_var
});
elewise_add
->
LinksFrom
({
elewise_add_in_var
,
bn_out_var
})
.
LinksTo
({
elewise_add_out_var
});
act
->
LinksFrom
({
elewise_add_out_var
}).
LinksTo
({
act_out_var
});
...
...
@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
.
LinksTo
({
d_elewise_add_in_var
,
d_bn_out_var
});
bn_grad
->
LinksFrom
({
bn_x_var
,
d_bn_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_saved_mean_var
,
bn_saved_variance_var
,
bn_reserve_space
})
->
LinksFrom
({
bn_x_var
,
d_bn_out_var
,
bn_scale_var
,
bn_bias_var
,
bn_saved_mean_var
,
bn_saved_variance_var
,
bn_reserve_space
})
.
LinksTo
({
d_bn_x_var
,
d_bn_scale_var
,
d_bn_bias_var
});
return
bn_grad
;
...
...
@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
PDNode
*
patterns
::
LinearAct
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
linear_x_var
,
const
std
::
unordered_set
<
std
::
string
>
&
act_types
,
bool
with_grad_link
,
const
std
::
unordered_set
<
std
::
string
>
&
act_types
,
bool
with_grad_link
,
bool
is_act_grad_x_from_act
)
{
auto
*
matmul_w_var
=
pattern
->
NewNode
(
matmul_w_repr
())
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
...
...
@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
PDNode
*
patterns
::
ElewiseAddMatmulAct
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
dout_var
,
const
std
::
unordered_set
<
std
::
string
>
&
act_grad_types
,
bool
without_x_gradient
,
bool
is_act_grad_x_from_act
)
{
bool
without_x_gradient
,
bool
is_act_grad_x_from_act
)
{
auto
*
ele_grad_bias_var
=
pattern
->
NewNode
(
ele_grad_bias_repr
())
->
assert_is_op_input
(
"elementwise_add_grad"
,
"Y"
);
...
...
@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
return
output_var
;
}
PDNode
*
patterns
::
Elementwise
::
operator
()(
PDNode
*
x_var
,
PDNode
*
y_var
,
PDNode
*
patterns
::
Elementwise
::
operator
()(
PDNode
*
x_var
,
PDNode
*
y_var
,
const
std
::
string
elementwise_type
)
{
auto
elementwise_op
=
pattern
->
NewNode
(
elementwise_op_repr
())
->
assert_is_op
(
elementwise_type
);
...
...
@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
}
PDNode
*
patterns
::
ResidualElementwise
::
operator
()(
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
bool
as_x
)
{
auto
elementwise_op
=
pattern
->
NewNode
(
elementwise_op_repr
())
->
assert_is_op
(
elementwise_type
);
...
...
@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
}
PDNode
*
patterns
::
ReshapeTransposeMatmulPattern
::
operator
()(
const
std
::
string
&
op_name
,
bool
with_reshape_xshape
,
const
std
::
string
&
op_name
,
bool
with_reshape_xshape
,
bool
with_transpose_xshape
)
{
auto
reshape_op
=
pattern
->
NewNode
(
reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
...
...
@@ -3098,8 +3159,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
transpose_out
->
assert_is_only_output_of_op
(
"transpose2"
);
auto
transpose_xshape
=
with_transpose_xshape
?
pattern
->
NewNode
(
transpose_xshape_repr
())
with_transpose_xshape
?
pattern
->
NewNode
(
transpose_xshape_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"transpose2"
,
"XShape"
)
:
nullptr
;
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
1ea9971a
...
...
@@ -122,10 +122,12 @@ struct PDNode {
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_not_op_input
(
const
std
::
string
&
argument
);
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
);
...
...
@@ -138,13 +140,15 @@ struct PDNode {
const
std
::
string
&
argument
);
PDNode
*
assert_is_ops_nth_input
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
,
const
std
::
string
&
argument
,
int
nth
);
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_ops_input
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
);
PDNode
*
assert_is_ops_input
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
,
const
std
::
string
&
argument
);
PDNode
*
assert_is_ops_nth_output
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
,
const
std
::
string
&
argument
,
int
nth
);
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_only_input_of_ops
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
);
...
...
@@ -164,10 +168,13 @@ struct PDNode {
}
private:
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
pattern_
(
pattern
),
name_
(
name
),
type_
(
type
)
{}
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
teller_
(
std
::
move
(
teller
)),
pattern_
(
pattern
),
...
...
@@ -398,16 +405,25 @@ struct KeyCounter {
return
x
;
}
#ifdef PADDLE_WITH_TENSORRT
static
int
IncCounter
(
const
std
::
string
&
key
)
{
return
dic_
[
key
]
++
;
}
static
void
CleanCounter
()
{
dic_
.
clear
();
}
private:
static
thread_local
std
::
unordered_map
<
std
::
string
,
size_t
>
dic_
;
#else
int
IncCounter
(
const
std
::
string
&
key
)
{
return
dic_
[
key
]
++
;
}
private:
std
::
unordered_map
<
std
::
string
,
size_t
>
dic_
;
#endif
};
// Generate a unique PDNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static
std
::
string
PDNodeName
(
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
,
size_t
id
,
const
std
::
string
&
repr
,
size_t
id
,
const
std
::
string
&
name
)
{
return
string
::
Sprintf
(
"%s/%s/%d/%s"
,
name_scope
,
repr
,
id
,
name
);
}
...
...
@@ -415,15 +431,15 @@ static std::string PDNodeName(const std::string& name_scope,
// The format is {name_scope}/{repr}/{id}
static
std
::
string
PDNodeName
(
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
)
{
return
string
::
Sprintf
(
"%s/%s/%d"
,
name_scope
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
return
string
::
Sprintf
(
"%s/%s/%d"
,
name_scope
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static
std
::
string
UniqueKey
(
const
std
::
string
&
repr
)
{
return
string
::
Sprintf
(
"%s/%d"
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
return
string
::
Sprintf
(
"%s/%d"
,
repr
,
KeyCounter
::
Instance
().
IncCounter
(
repr
));
}
// Declare a PDNode in a pattern, will create two methods:
...
...
@@ -440,17 +456,19 @@ static std::string UniqueKey(const std::string& repr) {
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), 0UL, \
PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), \
0UL, \
platform::errors::NotFound("Node not found for PDNode %s", \
pat.arg##_repr())); \
Node* var = subgraph.at(pat.arg##_n()); \
PADDLE_ENFORCE_NOT_NULL(
\
var, platform::errors::NotFound("node %s not exists in the sub-graph",
\
#arg));
PADDLE_ENFORCE_NOT_NULL(
var,
\
platform::errors::NotFound(
\
"node %s not exists in the sub-graph",
#arg));
// The base class of all the patterns.
struct
PatternBase
{
PatternBase
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
PatternBase
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
)
:
pattern
(
pattern
),
name_scope_
(
name_scope
),
...
...
@@ -476,7 +494,8 @@ struct ConvBN : public PatternBase {
ConvBN
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"conv_bn"
)
{}
PDNode
*
operator
()(
PDNode
*
conv_input
,
const
std
::
string
&
conv_type
,
PDNode
*
operator
()(
PDNode
*
conv_input
,
const
std
::
string
&
conv_type
,
bool
with_eltwise_add
);
// declare operator node's name
...
...
@@ -514,7 +533,8 @@ struct ConvActivation : public PatternBase {
ConvActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"conv_activation"
)
{}
PDNode
*
operator
()(
PDNode
*
conv_input
,
std
::
string
conv_type
=
"conv2d"
,
PDNode
*
operator
()(
PDNode
*
conv_input
,
std
::
string
conv_type
=
"conv2d"
,
std
::
string
activation_type
=
"relu"
);
// declare operator node's name
...
...
@@ -536,7 +556,8 @@ struct ElementwiseActivation : public PatternBase {
ElementwiseActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"elementwise_add_activation"
)
{}
PDNode
*
operator
()(
PDNode
*
elementwise_a
,
const
std
::
string
&
elementwise_type
,
PDNode
*
operator
()(
PDNode
*
elementwise_a
,
const
std
::
string
&
elementwise_type
,
const
std
::
string
&
activation_type
);
// declare operator node's name
...
...
@@ -936,7 +957,8 @@ struct LinearAct : public PatternBase {
PDNode
*
operator
()(
PDNode
*
x
,
const
std
::
unordered_set
<
std
::
string
>&
act_types
,
bool
with_grad_link
,
bool
is_act_grad_x_from_act
);
bool
with_grad_link
,
bool
is_act_grad_x_from_act
);
// declare operator node's name
PATTERN_DECL_NODE
(
matmul
);
...
...
@@ -965,7 +987,8 @@ struct ElewiseAddMatmulAct : public PatternBase {
PDNode
*
operator
()(
PDNode
*
x
,
const
std
::
unordered_set
<
std
::
string
>&
act_grad_types
,
bool
without_x_gradient
,
bool
is_act_grad_x_from_act
);
bool
without_x_gradient
,
bool
is_act_grad_x_from_act
);
// declare operator node's name
PATTERN_DECL_NODE
(
ele_add_grad
);
...
...
@@ -1062,7 +1085,8 @@ struct Elementwise : public PatternBase {
Elementwise
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"elementwise"
)
{}
PDNode
*
operator
()(
PDNode
*
x_var
,
PDNode
*
y_var
,
PDNode
*
operator
()(
PDNode
*
x_var
,
PDNode
*
y_var
,
const
std
::
string
elementwise_type
);
PATTERN_DECL_NODE
(
elementwise_op
);
...
...
@@ -1088,11 +1112,14 @@ struct ElementwiseOp : public PatternBase {
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
struct
ResidualElementwise
:
public
PatternBase
{
ResidualElementwise
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
ResidualElementwise
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
as_x
)
:
PatternBase
(
pattern
,
name_scope
,
"residual_elementwise"
)
{}
PDNode
*
operator
()(
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
bool
as_x
);
PDNode
*
operator
()(
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
bool
as_x
);
PATTERN_DECL_NODE
(
operator_output
);
PATTERN_DECL_NODE
(
residual_data
);
...
...
@@ -1467,8 +1494,8 @@ struct ConvElementwiseaddAct : public PatternBase {
// Conv + ElementwiseAdd + ElementwiseAdd + Activation
struct
ConvElementwiseadd2Act
:
public
PatternBase
{
ConvElementwiseadd2Act
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"conv_elementwiseadd2_elementwiseadd_act"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"conv_elementwiseadd2_elementwiseadd_act"
)
{}
PDNode
*
operator
()(
PDNode
*
conv_in
);
...
...
@@ -1702,7 +1729,8 @@ struct DequantOpFuse : public PatternBase {
DequantOpFuse
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"dequant_fuse"
)
{}
void
operator
()(
PDNode
*
quant_op_input
,
const
std
::
string
&
quantized_op_type
,
void
operator
()(
PDNode
*
quant_op_input
,
const
std
::
string
&
quantized_op_type
,
const
std
::
string
&
dequant_type
,
const
std
::
string
&
weight_name
);
...
...
@@ -1758,8 +1786,8 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
struct
DeleteQuantDequantFilterOpPattern
:
public
PatternBase
{
DeleteQuantDequantFilterOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"delete_quantdequant_filter_op_pattern"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"delete_quantdequant_filter_op_pattern"
)
{}
void
operator
()();
...
...
@@ -1773,7 +1801,8 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
struct
DeleteWeightQuantDequantLinearOpPattern
:
public
PatternBase
{
DeleteWeightQuantDequantLinearOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
:
PatternBase
(
pattern
,
name_scope
,
"delete_weight_quant_dequant_linear_op_pattern"
)
{}
void
operator
()();
...
...
@@ -1788,8 +1817,8 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
struct
DeleteQuantDequantLinearOpPattern
:
public
PatternBase
{
DeleteQuantDequantLinearOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"delete_quant_dequant_linear_op_pattern"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"delete_quant_dequant_linear_op_pattern"
)
{}
void
operator
()();
...
...
@@ -1814,7 +1843,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"reshape_transpose_matmul"
)
{}
PDNode
*
operator
()(
const
std
::
string
&
op_name
,
bool
with_reshape_xshape
,
PDNode
*
operator
()(
const
std
::
string
&
op_name
,
bool
with_reshape_xshape
,
bool
with_transpose_xshape
);
PATTERN_DECL_NODE
(
reshape_in
);
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
1ea9971a
...
...
@@ -82,9 +82,9 @@ namespace paddle {
using
inference
::
Singleton
;
#if PADDLE_WITH_TENSORRT
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
using
inference
::
tensorrt
::
TRTCalibratorEngine
;
using
inference
::
tensorrt
::
TRTCalibratorEngineManager
;
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
#endif
int
AnalysisPredictor
::
clone_num_
=
1
;
...
...
@@ -101,7 +101,8 @@ bool IsPersistable(const framework::VarDesc *var) {
}
}
// namespace
bool
PaddleTensorToLoDTensor
(
const
PaddleTensor
&
pt
,
framework
::
LoDTensor
*
t
,
bool
PaddleTensorToLoDTensor
(
const
PaddleTensor
&
pt
,
framework
::
LoDTensor
*
t
,
const
platform
::
Place
&
place
)
{
framework
::
DDim
ddim
=
phi
::
make_ddim
(
pt
.
shape
);
void
*
input_ptr
;
...
...
@@ -129,18 +130,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
if
(
platform
::
is_cpu_place
(
place
))
{
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
}
else
if
(
platform
::
is_ipu_place
(
place
))
{
#ifdef PADDLE_WITH_IPU
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Not compile with WITH_IPU, should not reach here."
));
#endif
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
place
),
false
,
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
place
),
false
,
platform
::
errors
::
InvalidArgument
(
"Only one choice can be made between CPU and XPU."
));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
...
@@ -148,8 +150,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
auto
*
dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
place
));
auto
dst_gpu_place
=
place
;
memory
::
Copy
(
dst_gpu_place
,
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
(),
memory
::
Copy
(
dst_gpu_place
,
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
(),
dev_ctx
->
stream
());
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
...
...
@@ -158,8 +163,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
}
else
if
(
platform
::
is_xpu_place
(
place
))
{
#ifdef PADDLE_WITH_XPU
auto
dst_xpu_place
=
place
;
memory
::
Copy
(
dst_xpu_place
,
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
());
memory
::
Copy
(
dst_xpu_place
,
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
());
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Not compile with XPU, should not reach here."
));
...
...
@@ -263,7 +271,8 @@ bool AnalysisPredictor::PrepareProgram(
}
bool
AnalysisPredictor
::
CreateExecutor
()
{
if
(
config_
.
use_gpu
())
{
PADDLE_ENFORCE_EQ
(
config_
.
use_xpu
(),
false
,
PADDLE_ENFORCE_EQ
(
config_
.
use_xpu
(),
false
,
platform
::
errors
::
InvalidArgument
(
"Only one choice can be made between CPU and XPU."
));
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
gpu_device_id
());
...
...
@@ -357,7 +366,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
}
static
void
DisablePrepareDataOpt
(
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program
,
int
block
,
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program
,
int
block
,
bool
pre_disable_opt
)
{
bool
disable_opt
=
false
;
auto
&
infer_block
=
inference_program
->
Block
(
block
);
...
...
@@ -367,8 +377,8 @@ static void DisablePrepareDataOpt(
}
if
(
op
->
HasAttr
(
"sub_block"
))
{
int
blockID
=
op
->
GetBlockAttrId
(
"sub_block"
);
DisablePrepareDataOpt
(
inference_program
,
blockID
,
disable_opt
||
pre_disable_opt
);
DisablePrepareDataOpt
(
inference_program
,
blockID
,
disable_opt
||
pre_disable_opt
);
}
// disable prepare data if unfriendly op is found
if
(
!
disable_opt
)
{
...
...
@@ -386,8 +396,8 @@ bool AnalysisPredictor::PrepareExecutor() {
#endif
DisablePrepareDataOpt
(
inference_program_
,
0
,
false
);
executor_
->
Prepare
(
sub_scope_
,
*
inference_program_
,
0
,
config_
.
use_feed_fetch_ops_
);
executor_
->
Prepare
(
sub_scope_
,
*
inference_program_
,
0
,
config_
.
use_feed_fetch_ops_
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scope_
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -433,8 +443,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
feed_fetch_vars
.
emplace_back
(
pair
.
second
);
}
fleet_exe_
->
Init
(
config_
.
dist_config
().
carrier_id
(),
*
(
inference_program_
.
get
()),
scope_
.
get
(),
place_
,
1
,
{
task_node_
.
get
()},
id_to_rank
,
feed_fetch_vars
);
*
(
inference_program_
.
get
()),
scope_
.
get
(),
place_
,
1
,
{
task_node_
.
get
()},
id_to_rank
,
feed_fetch_vars
);
return
true
;
}
...
...
@@ -471,8 +486,12 @@ bool AnalysisPredictor::CommInit() {
peer_endpoints
.
emplace_back
(
config_
.
dist_config
().
trainer_endpoints
()[
rank
]);
}
InsertCommOp
(
var_name_base
+
std
::
to_string
(
order
),
ranks_in_group
,
rank_in_group
,
peer_endpoints
,
comm_init_block
,
ring_id
);
InsertCommOp
(
var_name_base
+
std
::
to_string
(
order
),
ranks_in_group
,
rank_in_group
,
peer_endpoints
,
comm_init_block
,
ring_id
);
order
+=
1
;
}
framework
::
NaiveExecutor
e
(
place_
);
...
...
@@ -484,8 +503,11 @@ bool AnalysisPredictor::CommInit() {
}
void
AnalysisPredictor
::
InsertCommOp
(
std
::
string
tmp_var_name
,
int
nranks
,
int
rank
,
const
std
::
vector
<
std
::
string
>
&
peer_endpoints
,
framework
::
BlockDesc
*
block
,
std
::
string
tmp_var_name
,
int
nranks
,
int
rank
,
const
std
::
vector
<
std
::
string
>
&
peer_endpoints
,
framework
::
BlockDesc
*
block
,
int
ring_id
)
{
/*
* tmp_var_name: the var name for var comm_id
...
...
@@ -542,7 +564,8 @@ bool AnalysisPredictor::LoadConverterConfig(
<<
config_
.
dist_config
().
comm_init_config
()
<<
"
\n
"
;
std
::
ifstream
fin
(
config_
.
dist_config
().
comm_init_config
(),
std
::
ios
::
in
);
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
platform
::
errors
::
NotFound
(
"Cannot open file %s, please confirm whether the file is normal."
,
config_
.
dist_config
().
comm_init_config
()));
...
...
@@ -686,8 +709,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
timer
.
tic
();
// set feed variable
framework
::
Scope
*
scope
=
sub_scope_
?
sub_scope_
:
scope_
.
get
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
PreconditionNotMet
(
"The scope should not be nullptr."
));
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
PreconditionNotMet
(
"The scope should not be nullptr."
));
if
(
!
SetFeed
(
inputs
,
scope
))
{
LOG
(
ERROR
)
<<
"fail to set feed"
;
return
false
;
...
...
@@ -790,9 +814,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
for
(
size_t
i
=
0
;
i
<
fetches_
.
size
();
++
i
)
{
int
idx
=
BOOST_GET_CONST
(
int
,
fetches_
[
i
]
->
GetAttr
(
"col"
));
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
idx
),
i
,
static_cast
<
size_t
>
(
idx
),
i
,
platform
::
errors
::
InvalidArgument
(
"Fetch op's col attr(%d) should be equal to the index(%d)"
,
idx
,
"Fetch op's col attr(%d) should be equal to the index(%d)"
,
idx
,
i
));
framework
::
FetchType
&
fetch_var
=
framework
::
GetFetchVariable
(
*
scope
,
"fetch"
,
idx
);
...
...
@@ -833,7 +859,8 @@ void AnalysisPredictor::PrepareArgument() {
if
(
!
config_
.
model_dir
().
empty
())
{
argument_
.
SetModelDir
(
config_
.
model_dir
());
}
else
{
PADDLE_ENFORCE_EQ
(
config_
.
prog_file
().
empty
(),
false
,
PADDLE_ENFORCE_EQ
(
config_
.
prog_file
().
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"Either model_dir or prog_file should be set."
));
std
::
string
dir
=
inference
::
analysis
::
GetDirRoot
(
config_
.
prog_file
());
...
...
@@ -969,7 +996,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
Analyzer
().
Run
(
&
argument_
);
PADDLE_ENFORCE_EQ
(
argument_
.
scope_valid
(),
true
,
argument_
.
scope_valid
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The argument scope should be valid."
));
VLOG
(
5
)
<<
"to prepare executor"
;
ARGUMENT_CHECK_FIELD
((
&
argument_
),
ir_analyzed_program
);
...
...
@@ -1008,8 +1036,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
AnalysisConfig
&
config
)
{
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
AnalysisConfig
&
config
)
{
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
if
(
config
.
glog_info_disabled
())
{
...
...
@@ -1018,7 +1047,8 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
}
VLOG
(
3
)
<<
"create AnalysisConfig"
;
PADDLE_ENFORCE_EQ
(
config
.
is_valid
(),
true
,
config
.
is_valid
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Note: Each config can only be used for one predictor."
));
...
...
@@ -1035,11 +1065,13 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
std
::
call_once
(
gflags_initialized
,
[
&
]()
{
std
::
vector
<
std
::
string
>
gflags
;
PADDLE_ENFORCE_GE
(
config
.
memory_pool_init_size_mb
(),
0.
f
,
config
.
memory_pool_init_size_mb
(),
0.
f
,
platform
::
errors
::
InvalidArgument
(
"The size of memory pool should be greater than 0."
));
PADDLE_ENFORCE_GE
(
config
.
gpu_device_id
(),
0
,
config
.
gpu_device_id
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Invalid device id (%d). The device id should be greater than 0."
,
config
.
gpu_device_id
()));
...
...
@@ -1105,6 +1137,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
config
.
SetInValid
();
auto
predictor_p
=
dynamic_cast
<
AnalysisPredictor
*>
(
predictor
.
get
());
#ifdef PADDLE_WITH_TENSORRT
paddle
::
framework
::
ir
::
patterns
::
KeyCounter
::
Instance
().
CleanCounter
();
#endif
if
(
!
predictor_p
->
Init
(
nullptr
))
{
return
nullptr
;
}
...
...
@@ -1154,8 +1190,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
}
void
AnalysisPredictor
::
CreateFeedFetchVar
(
framework
::
Scope
*
scope
)
{
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"The scope should not be nullptr."
));
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"The scope should not be nullptr."
));
auto
*
var
=
scope
->
Var
(
"feed"
);
var
->
GetMutable
<
framework
::
FeedList
>
();
var
=
scope
->
Var
(
"fetch"
);
...
...
@@ -1176,8 +1213,9 @@ AnalysisPredictor::GetInputTensorShape() {
std
::
vector
<
std
::
string
>
names
=
GetInputNames
();
for
(
std
::
string
name
:
names
)
{
auto
*
var
=
inference_program_
->
Block
(
0
).
FindVar
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
PreconditionNotMet
(
"Input %s does not exist."
,
name
));
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
PreconditionNotMet
(
"Input %s does not exist."
,
name
));
input_shapes
[
name
]
=
var
->
GetShape
();
}
return
input_shapes
;
...
...
@@ -1398,7 +1436,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>
counter
;
for
(
auto
&
it
:
m
)
counter
.
push_back
(
it
);
std
::
sort
(
counter
.
begin
(),
counter
.
end
(),
counter
.
begin
(),
counter
.
end
(),
[](
std
::
pair
<
int32_t
,
int32_t
>
&
a
,
std
::
pair
<
int32_t
,
int32_t
>
&
b
)
{
return
a
.
second
>
b
.
second
;
});
...
...
@@ -1420,8 +1459,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
opt_shapes
[
name
]
=
opt_shape
;
}
inference
::
SerializeShapeRangeInfo
(
config_
.
shape_range_info_path
(),
min_shapes
,
max_shapes
,
opt_shapes
);
inference
::
SerializeShapeRangeInfo
(
config_
.
shape_range_info_path
(),
min_shapes
,
max_shapes
,
opt_shapes
);
}
bool
AnalysisPredictor
::
LoadProgramDesc
()
{
...
...
@@ -1441,7 +1480,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
return
false
;
}
LOG
(
ERROR
)
<<
string
::
Sprintf
(
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
(),
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
(),
config_
.
params_file
());
return
false
;
}
...
...
@@ -1453,7 +1493,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
// Read binary
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
platform
::
errors
::
NotFound
(
"Cannot open file %s, please confirm whether the file is normal."
,
filename
));
...
...
@@ -1555,7 +1596,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
#if PADDLE_WITH_TENSORRT
bool
AnalysisPredictor
::
SaveTrtCalibToDisk
()
{
PADDLE_ENFORCE_EQ
(
config_
.
tensorrt_engine_enabled
(),
true
,
PADDLE_ENFORCE_EQ
(
config_
.
tensorrt_engine_enabled
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This func can be invoked only in trt mode"
));
auto
&
block
=
inference_program_
->
Block
(
0
);
...
...
@@ -1782,8 +1824,10 @@ Predictor::Predictor(const Config &config) {
<<
"Paddle2ONNX do't support convert the Model, fall back to using "
"Paddle Inference."
;
}
else
{
predictor_
=
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kONNXRuntime
>
(
config
);
predictor_
=
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kONNXRuntime
>
(
config
);
return
;
}
#else
...
...
@@ -1793,8 +1837,10 @@ Predictor::Predictor(const Config &config) {
"fall back to using Paddle Inference."
;
#endif
}
predictor_
=
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kAnalysis
>
(
config
);
predictor_
=
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kAnalysis
>
(
config
);
}
std
::
vector
<
std
::
string
>
Predictor
::
GetInputNames
()
{
...
...
@@ -1876,7 +1922,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
namespace
services
{
PredictorPool
::
PredictorPool
(
const
Config
&
config
,
size_t
size
)
{
PADDLE_ENFORCE_GE
(
size
,
1UL
,
size
,
1UL
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The predictor pool size should be greater than 1, but it's (%d)"
,
size
));
...
...
@@ -1895,9 +1942,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
Predictor
*
PredictorPool
::
Retrive
(
size_t
idx
)
{
PADDLE_ENFORCE_LT
(
idx
,
preds_
.
size
()
+
1
,
idx
,
preds_
.
size
()
+
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"There are (%d) predictors in the pool, but the idx is (%d)"
,
idx
,
"There are (%d) predictors in the pool, but the idx is (%d)"
,
idx
,
preds_
.
size
()
+
1
));
if
(
idx
==
0
)
{
return
main_pred_
.
get
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录