Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3e6aa498
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e6aa498
编写于
3月 29, 2019
作者:
Z
Zhaolong Xing
提交者:
GitHub
3月 29, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16526 from NHZlX/refine_trt_anakin
refine subgraph trt and anakin
上级
e014950e
7cde2d9e
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
435 addition
and
142 deletion
+435
-142
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-14
paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc
...le/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc
+7
-7
paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h
paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h
+2
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+83
-11
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+21
-4
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+173
-0
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
+35
-0
paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.cc
...amework/ir/simplify_anakin_priorbox_detection_out_pass.cc
+23
-33
paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h
...ramework/ir/simplify_anakin_priorbox_detection_out_pass.h
+0
-1
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
.../fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
+10
-25
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
...e/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
+2
-1
paddle/fluid/inference/anakin/convert/density_prior_box.cc
paddle/fluid/inference/anakin/convert/density_prior_box.cc
+33
-16
paddle/fluid/inference/anakin/convert/op_converter.h
paddle/fluid/inference/anakin/convert/op_converter.h
+1
-1
paddle/fluid/inference/anakin/op_teller.cc
paddle/fluid/inference/anakin/op_teller.cc
+2
-0
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
...luid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
+5
-5
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+1
-0
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
...ence/analysis/passes/ir_params_sync_among_devices_pass.cc
+1
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+11
-16
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+15
-6
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
+4
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
3e6aa498
...
@@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
...
@@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library
(
identity_scale_op_clean_pass base
)
pass_library
(
identity_scale_op_clean_pass base
)
pass_library
(
sync_batch_norm_pass base
)
pass_library
(
sync_batch_norm_pass base
)
pass_library
(
runtime_context_cache_pass base
)
pass_library
(
runtime_context_cache_pass base
)
pass_library
(
simplify_anakin_detection_pattern
_pass inference
)
pass_library
(
quant_conv2d_dequant_fuse
_pass inference
)
pass_library
(
anakin_
fillconstant_elementwisemul_fuse inference
)
pass_library
(
fillconstant_elementwisemul_fuse inference
)
# There may be many transpose-flatten structures in a model, and the output of
if
(
ANAKIN_FOUND
)
# these structures will be used as inputs to the concat Op. This pattern will
pass_library
(
simplify_anakin_priorbox_detection_out_pass inference
)
# be detected by our pass. The index here represents the number of structures in the
endif
()
# pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models.
foreach
(
index RANGE 2 6
)
file
(
APPEND
${
pass_file
}
"USE_PASS(transpose_flatten
${
index
}
_concat_fuse_pass);
\n
"
)
endforeach
()
foreach
(
index RANGE 2 6
)
file
(
APPEND
${
pass_file
}
"USE_PASS(simplify_anakin_detection_pattern_pass
${
index
}
);
\n
"
)
endforeach
()
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base mkldnn
)
pass_library
(
mkldnn_placement_pass base mkldnn
)
...
...
paddle/fluid/framework/ir/
anakin_
fillconstant_elementwisemul_fuse.cc
→
paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc
浏览文件 @
3e6aa498
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include <memory>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/
anakin_
fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -29,8 +29,8 @@ namespace ir {
...
@@ -29,8 +29,8 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
GET_IR_NODE(elementwise_mul_out);
void
Anakin
FillconstantElementwisemulFuse
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
FillconstantElementwisemulFuse
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"
anakin_
fillconstant_elementwisemul_fuse"
;
const
std
::
string
pattern_name
=
"fillconstant_elementwisemul_fuse"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
...
@@ -39,8 +39,8 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
...
@@ -39,8 +39,8 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
->
assert_is_op_input
(
"elementwise_mul"
,
"X"
)
->
assert_is_op_input
(
"elementwise_mul"
,
"X"
)
->
AsInput
();
->
AsInput
();
patterns
::
Anakin
FillConstantElementWiseMulFuse
pattern
(
gpd
.
mutable_pattern
(),
patterns
::
FillConstantElementWiseMulFuse
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern_name
);
pattern
(
x
);
pattern
(
x
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
@@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
...
@@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
anakin_
fillconstant_elementwisemul_fuse
,
REGISTER_PASS
(
fillconstant_elementwisemul_fuse
,
paddle
::
framework
::
ir
::
Anakin
FillconstantElementwisemulFuse
);
paddle
::
framework
::
ir
::
FillconstantElementwisemulFuse
);
paddle/fluid/framework/ir/
anakin_
fillconstant_elementwisemul_fuse.h
→
paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h
浏览文件 @
3e6aa498
...
@@ -21,9 +21,9 @@ namespace paddle {
...
@@ -21,9 +21,9 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
Anakin
FillconstantElementwisemulFuse
:
public
FusePassBase
{
class
FillconstantElementwisemulFuse
:
public
FusePassBase
{
public:
public:
virtual
~
Anakin
FillconstantElementwisemulFuse
()
{}
virtual
~
FillconstantElementwisemulFuse
()
{}
protected:
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
3e6aa498
...
@@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
...
@@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
}
}
PDNode
*
patterns
::
AnakinDetectionPattern
::
operator
()(
PDNode
*
patterns
::
AnakinDetectionPattern
::
operator
()(
std
::
vector
<
PDNode
*>
conv_in
,
int
times
)
{
std
::
vector
<
PDNode
*>
conv_in
,
int
times
,
std
::
string
priorbox_type
,
bool
is_reshape
)
{
// The times represents the repeat times of the
// The times represents the repeat times of the
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
const
int
kNumFields
=
7
;
const
int
kNumFields
=
7
;
...
@@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
...
@@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
const
int
kMultiClassSecondInputNmsOffset
=
times
+
1
;
const
int
kMultiClassSecondInputNmsOffset
=
times
+
1
;
std
::
vector
<
PDNode
*>
nodes
;
std
::
vector
<
PDNode
*>
nodes
;
std
::
string
op_after_priorbox
=
is_reshape
?
"reshape2"
:
"flatten2"
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"prior_box"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"prior_box"
+
std
::
to_string
(
i
)))
->
assert_is_op
(
"density_prior_box"
));
->
assert_is_op
(
priorbox_type
));
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"box_out"
+
std
::
to_string
(
i
)))
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"box_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"density_prior_box"
,
"Boxes"
)
->
assert_is_op_output
(
priorbox_type
,
"Boxes"
)
->
assert_is_op_input
(
"reshape2"
,
"X"
)
->
assert_is_op_input
(
op_after_priorbox
,
"X"
)
->
AsIntermediate
());
->
AsIntermediate
());
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"reshape1"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"reshape1"
+
std
::
to_string
(
i
)))
->
assert_is_op
(
"reshape2"
));
->
assert_is_op
(
op_after_priorbox
));
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"reshape1_out"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"reshape1_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"reshape2"
)
->
assert_is_op_output
(
op_after_priorbox
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
i
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
i
)
->
AsIntermediate
());
->
AsIntermediate
());
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"box_var_out"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"box_var_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"density_prior_box"
,
"Variances"
)
->
assert_is_op_output
(
priorbox_type
,
"Variances"
)
->
assert_is_op_input
(
"reshape2"
,
"X"
)
->
assert_is_op_input
(
op_after_priorbox
,
"X"
)
->
AsIntermediate
());
->
AsIntermediate
());
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"reshape2"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"reshape2"
+
std
::
to_string
(
i
)))
->
assert_is_op
(
"reshape2"
));
->
assert_is_op
(
op_after_priorbox
));
nodes
.
push_back
(
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"reshape2_out"
+
std
::
to_string
(
i
)))
pattern
->
NewNode
(
GetNodeName
(
"reshape2_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"reshape2"
)
->
assert_is_op_output
(
op_after_priorbox
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
i
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
i
)
->
AsIntermediate
());
->
AsIntermediate
());
}
}
...
@@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
...
@@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
return
multiclass_nms_out
;
return
multiclass_nms_out
;
}
}
PDNode
*
patterns
::
Anakin
FillConstantElementWiseMulFuse
::
operator
()(
PDNode
*
patterns
::
FillConstantElementWiseMulFuse
::
operator
()(
PDNode
*
elementwise_op_input
)
{
PDNode
*
elementwise_op_input
)
{
auto
fill_constant
=
auto
fill_constant
=
pattern
->
NewNode
(
fill_constant_repr
())
->
assert_is_op
(
"fill_constant"
);
pattern
->
NewNode
(
fill_constant_repr
())
->
assert_is_op
(
"fill_constant"
);
...
@@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
...
@@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
return
elementwise_mul_out
;
return
elementwise_mul_out
;
}
}
void
patterns
::
QuantDequantOpFuse
::
operator
()(
PDNode
*
quant_op_input
,
const
std
::
string
&
op_type
,
const
std
::
string
&
weight_name
,
int
times
)
{
const
int
kNumFields
=
5
;
const
int
kQuantizedWeightOffset
=
0
;
const
int
kQuantizedOpOffset
=
1
;
const
int
kQuantizedOpOutOffset
=
2
;
const
int
kDequantOpOffset
=
3
;
const
int
kDequantOpOutOffset
=
4
;
// the quant op always be one.
auto
quant_op_in_scale
=
pattern
->
NewNode
(
GetNodeName
(
"quant_op_in_scale"
))
->
assert_is_op_input
(
"fake_quantize_range_abs_max"
,
"InScale"
)
->
AsInput
();
auto
quant_op
=
pattern
->
NewNode
(
GetNodeName
(
"quant_op"
))
->
assert_is_op
(
"fake_quantize_range_abs_max"
);
auto
quant_op_out_scale
=
pattern
->
NewNode
(
GetNodeName
(
"quant_op_out_scale"
))
->
assert_is_op_output
(
"fake_quantize_range_abs_max"
,
"OutScale"
)
->
assert_is_op_input
(
"fake_dequantize_max_abs"
,
"Scale"
)
->
AsIntermediate
();
auto
quant_op_out
=
pattern
->
NewNode
(
GetNodeName
(
"quant_op_out"
))
->
assert_is_op_output
(
"fake_quantize_range_abs_max"
,
"Out"
)
->
assert_is_op_input
(
op_type
)
->
AsIntermediate
();
// there are 'times' quantized and dequant op
std
::
vector
<
PDNode
*>
nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"quantized_op_weight"
)
+
std
::
to_string
(
i
))
->
assert_is_op_input
(
op_type
,
weight_name
)
->
AsInput
());
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"quantized_op"
)
+
std
::
to_string
(
i
))
->
assert_is_op
(
op_type
));
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"quantized_op_out"
)
+
std
::
to_string
(
i
))
->
assert_is_op_output
(
op_type
)
->
assert_is_op_input
(
"fake_dequantize_max_abs"
,
"X"
)
->
AsIntermediate
());
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"dequant_op"
)
+
std
::
to_string
(
i
))
->
assert_is_op
(
"fake_dequantize_max_abs"
));
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"dequant_op_out"
)
+
std
::
to_string
(
i
))
->
assert_is_op_output
(
"fake_dequantize_max_abs"
,
"Out"
)
->
AsOutput
());
}
quant_op
->
LinksFrom
({
quant_op_input
,
quant_op_in_scale
});
quant_op_out
->
LinksFrom
({
quant_op
});
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]
->
LinksFrom
(
{
quant_op_out
,
nodes
[
i
*
kNumFields
+
kQuantizedWeightOffset
]});
nodes
[
i
*
kNumFields
+
kQuantizedOpOutOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]});
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOutOffset
],
quant_op_out_scale
});
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]});
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
3e6aa498
...
@@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase {
...
@@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase {
AnakinDetectionPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
AnakinDetectionPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"anakin_detect_pattern"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"anakin_detect_pattern"
)
{}
PDNode
*
operator
()(
std
::
vector
<
PDNode
*>
conv_inputs
,
int
times
);
PDNode
*
operator
()(
std
::
vector
<
PDNode
*>
conv_inputs
,
int
times
,
std
::
string
priorbox_type
,
bool
is_reshape
);
std
::
string
GetNodeName
(
const
std
::
string
&
op_type
)
{
std
::
string
GetNodeName
(
const
std
::
string
&
op_type
)
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
op_type
);
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
op_type
);
...
@@ -859,9 +860,9 @@ struct AnakinDetectionPattern : public PatternBase {
...
@@ -859,9 +860,9 @@ struct AnakinDetectionPattern : public PatternBase {
}
}
};
};
struct
Anakin
FillConstantElementWiseMulFuse
:
public
PatternBase
{
struct
FillConstantElementWiseMulFuse
:
public
PatternBase
{
Anakin
FillConstantElementWiseMulFuse
(
PDPattern
*
pattern
,
FillConstantElementWiseMulFuse
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
:
PatternBase
(
pattern
,
name_scope
,
"anakin_fillconstant_elementwisemul_fuse"
)
{}
"anakin_fillconstant_elementwisemul_fuse"
)
{}
...
@@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
...
@@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
PATTERN_DECL_NODE
(
elementwise_mul_out
);
PATTERN_DECL_NODE
(
elementwise_mul_out
);
};
};
struct
QuantDequantOpFuse
:
public
PatternBase
{
QuantDequantOpFuse
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"quant_dequant_fuse"
)
{}
void
operator
()(
PDNode
*
quant_op_input
,
const
std
::
string
&
op_name
,
const
std
::
string
&
weight_name
,
int
times
=
1
);
std
::
string
GetNodeName
(
const
std
::
string
&
op_type
)
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
op_type
);
}
PDNode
*
GetPDNode
(
const
std
::
string
&
op_type
)
{
return
pattern
->
RetrieveNode
(
GetNodeName
(
op_type
));
}
};
}
// namespace patterns
}
// namespace patterns
// Link two ir::Nodes from each other.
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
0 → 100644
浏览文件 @
3e6aa498
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
RunQuantDequant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
int
times
,
std
::
string
op_type
)
{
const
std
::
string
pattern_name
=
"quant_dequant_fuse"
;
// FusePassBase::Init(pattern_name, graph);
const
int
kNumFields
=
5
;
const
int
kQuantizedWeightOffset
=
0
;
const
int
kQuantizedOpOffset
=
1
;
const
int
kQuantizedOpOutOffset
=
2
;
const
int
kDequantOpOffset
=
3
;
const
int
kDequantOpOutOffset
=
4
;
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
"x"
)
->
assert_is_op_input
(
"fake_quantize_range_abs_max"
,
"X"
)
->
AsInput
();
std
::
string
quantized_op_type
=
""
;
std
::
string
weight_name
=
""
;
if
(
op_type
==
"conv2d"
)
{
quantized_op_type
=
"conv2d"
;
weight_name
=
"Filter"
;
}
else
if
(
op_type
==
"conv2d_fusion"
)
{
quantized_op_type
=
"conv2d_fusion"
;
weight_name
=
"Filter"
;
}
else
if
(
op_type
==
"mul"
)
{
quantized_op_type
=
"mul"
;
weight_name
=
"Y"
;
}
else
if
(
op_type
==
"fc"
)
{
quantized_op_type
=
"fc"
;
weight_name
=
"W"
;
}
else
{
PADDLE_ENFORCE
(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"now."
);
}
patterns
::
QuantDequantOpFuse
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
x
,
quantized_op_type
,
weight_name
,
times
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
PADDLE_ENFORCE
(
subgraph
.
count
(
x
));
auto
*
input_node
=
subgraph
.
at
(
x
);
Node
*
quant_op_in_scale
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"quant_op_in_scale"
));
Node
*
quant_op
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"quant_op"
));
Node
*
quant_op_out_scale
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"quant_op_out_scale"
));
Node
*
quant_op_out
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"quant_op_out"
));
std
::
vector
<
Node
*>
nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"quantized_op_weight"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"quantized_op"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"quantized_op_out"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"dequant_op"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"dequant_op_out"
+
std
::
to_string
(
i
))));
}
int
bit_length
=
boost
::
get
<
int
>
(
quant_op
->
Op
()
->
GetAttr
(
"bit_length"
));
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
// Prepare input scale
std
::
string
input_scale_var_name
=
quant_op
->
Op
()
->
Input
(
"InScale"
).
front
();
PADDLE_ENFORCE
(
scope
);
const
LoDTensor
&
input_scale_tensor
=
scope
->
FindVar
(
input_scale_var_name
)
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
input_scale_tensor
.
place
()));
const
float
*
input_scale_data
=
input_scale_tensor
.
data
<
float
>
();
float
input_scale
=
input_scale_data
[
0
];
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
// max_range = (range * range) / weight_scale
float
max_range
=
boost
::
get
<
float
>
(
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]
->
Op
()
->
GetAttr
(
"max_range"
));
float
weight_scale
=
(
range
*
range
)
/
max_range
;
auto
base_op_desc
=
*
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]
->
Op
()
->
Proto
();
std
::
string
new_input
=
input_node
->
Name
();
std
::
string
new_output
=
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]
->
Name
();
framework
::
OpDesc
new_op_desc
(
base_op_desc
,
nullptr
);
new_op_desc
.
SetType
(
quantized_op_type
);
if
(
quantized_op_type
==
"conv2d"
||
quantized_op_type
==
"conv2d_fusion"
)
{
new_op_desc
.
SetInput
(
"Input"
,
{
new_input
});
new_op_desc
.
SetOutput
(
"Output"
,
{
new_output
});
}
else
if
(
quantized_op_type
==
"fc"
)
{
new_op_desc
.
SetInput
(
"Input"
,
{
new_input
});
new_op_desc
.
SetOutput
(
"Out"
,
{
new_output
});
}
else
if
(
quantized_op_type
==
"mul"
)
{
new_op_desc
.
SetInput
(
"X"
,
{
new_input
});
new_op_desc
.
SetOutput
(
"Out"
,
{
new_output
});
}
new_op_desc
.
SetAttr
(
"enable_int8"
,
true
);
new_op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
new_op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
new_op_desc
.
Flush
();
auto
*
new_op
=
graph
->
CreateOpNode
(
&
new_op_desc
);
IR_NODE_LINK_TO
(
input_node
,
new_op
);
IR_NODE_LINK_TO
(
nodes
[
i
*
kNumFields
+
kQuantizedWeightOffset
],
new_op
);
IR_NODE_LINK_TO
(
new_op
,
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kQuantizedOpOutOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]);
}
delete_nodes
.
insert
(
quant_op_in_scale
);
delete_nodes
.
insert
(
quant_op
);
delete_nodes
.
insert
(
quant_op_out
);
delete_nodes
.
insert
(
quant_op_out_scale
);
// Delete the unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
};
gpd
(
graph
,
handler
);
}
void
QuantDequantFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"quant_dequant_fuse"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
std
::
unordered_set
<
std
::
string
>
quantized_op_types
=
{
"conv2d"
,
"mul"
};
auto
*
scope
=
param_scope
();
for
(
auto
&
op_type
:
quantized_op_types
)
{
for
(
int
i
=
1
;
i
<=
6
;
i
++
)
{
RunQuantDequant
(
graph
,
scope
,
i
,
op_type
);
}
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
quant_conv2d_dequant_fuse_pass
,
paddle
::
framework
::
ir
::
QuantDequantFusePass
);
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
0 → 100644
浏览文件 @
3e6aa498
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
QuantDequantFusePass
:
public
FusePassBase
{
public:
virtual
~
QuantDequantFusePass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/simplify_anakin_
detection_pattern
_pass.cc
→
paddle/fluid/framework/ir/simplify_anakin_
priorbox_detection_out
_pass.cc
浏览文件 @
3e6aa498
...
@@ -17,25 +17,24 @@
...
@@ -17,25 +17,24 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/simplify_anakin_
detection_pattern
_pass.h"
#include "paddle/fluid/framework/ir/simplify_anakin_
priorbox_detection_out
_pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
template
<
int
times
>
void
RunSimplifyAnakinDetection
(
ir
::
Graph
*
graph
,
int
times
,
bool
is_density
,
void
SimplifyAnakinDetectionPatternPass
<
times
>::
ApplyImpl
(
bool
is_reshape
)
{
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
const
std
::
string
pattern_name
=
"simplify_anakin_detection_pattern_pass"
+
std
::
to_string
(
times
);
"simplify_anakin_detection_pattern_pass"
+
std
::
to_string
(
times
);
FusePassBase
::
Init
(
pattern_name
,
graph
)
;
std
::
string
priorbox_type
=
is_density
?
"density_prior_box"
:
"prior_box"
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
std
::
vector
<
PDNode
*>
input_nodes
;
std
::
vector
<
PDNode
*>
input_nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
input_nodes
.
push_back
(
gpd
.
mutable_pattern
()
input_nodes
.
push_back
(
gpd
.
mutable_pattern
()
->
NewNode
(
"x"
+
std
::
to_string
(
i
))
->
NewNode
(
"x"
+
std
::
to_string
(
i
))
->
assert_is_op_input
(
"density_prior_box"
,
"Input"
)
->
assert_is_op_input
(
priorbox_type
,
"Input"
)
->
AsInput
());
->
AsInput
());
}
}
input_nodes
.
push_back
(
gpd
.
mutable_pattern
()
input_nodes
.
push_back
(
gpd
.
mutable_pattern
()
...
@@ -49,7 +48,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
...
@@ -49,7 +48,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
->
AsInput
());
->
AsInput
());
patterns
::
AnakinDetectionPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
patterns
::
AnakinDetectionPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
input_nodes
,
times
);
pattern
(
input_nodes
,
times
,
priorbox_type
,
is_reshape
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
...
@@ -119,8 +118,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
...
@@ -119,8 +118,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
boost
::
get
<
std
::
string
>
(
box_coder_op
->
Op
()
->
GetAttr
(
"code_type"
));
boost
::
get
<
std
::
string
>
(
box_coder_op
->
Op
()
->
GetAttr
(
"code_type"
));
bool
box_normalized
=
bool
box_normalized
=
boost
::
get
<
bool
>
(
box_coder_op
->
Op
()
->
GetAttr
(
"box_normalized"
));
boost
::
get
<
bool
>
(
box_coder_op
->
Op
()
->
GetAttr
(
"box_normalized"
));
// auto variance =
// boost::get<std::vector<float>>(box_coder_op->Op()->GetAttr("variance"));
int
background_label
=
int
background_label
=
boost
::
get
<
int
>
(
multiclass_nms
->
Op
()
->
GetAttr
(
"background_label"
));
boost
::
get
<
int
>
(
multiclass_nms
->
Op
()
->
GetAttr
(
"background_label"
));
float
score_threshold
=
float
score_threshold
=
...
@@ -138,7 +136,6 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
...
@@ -138,7 +136,6 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
nodes
[
i
*
kNumFields
+
kPriorBoxLocOffset
]
->
Name
());
nodes
[
i
*
kNumFields
+
kPriorBoxLocOffset
]
->
Name
());
}
}
// int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
framework
::
OpDesc
concat1_desc
;
framework
::
OpDesc
concat1_desc
;
concat1_desc
.
SetType
(
"concat"
);
concat1_desc
.
SetType
(
"concat"
);
concat1_desc
.
SetInput
(
"X"
,
concat1_input_names
);
concat1_desc
.
SetInput
(
"X"
,
concat1_input_names
);
...
@@ -213,31 +210,24 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
...
@@ -213,31 +210,24 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
}
}
template
class
SimplifyAnakinDetectionPatternPass
<
1
>;
void
SimplifyAnakinDetectionPatternPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
template
class
SimplifyAnakinDetectionPatternPass
<
2
>;
const
int
pattern_nums
=
6
;
template
class
SimplifyAnakinDetectionPatternPass
<
3
>;
const
std
::
string
pattern_name
=
"simplify_anakin_detection_pattern_pass"
;
template
class
SimplifyAnakinDetectionPatternPass
<
4
>;
FusePassBase
::
Init
(
pattern_name
,
graph
);
template
class
SimplifyAnakinDetectionPatternPass
<
5
>;
std
::
vector
<
bool
>
options
=
{
true
,
false
};
template
class
SimplifyAnakinDetectionPatternPass
<
6
>;
for
(
const
auto
&
is_density
:
options
)
{
for
(
const
auto
&
is_reshape
:
options
)
{
for
(
int
i
=
1
;
i
<=
pattern_nums
;
i
++
)
{
RunSimplifyAnakinDetection
(
graph
,
i
,
is_density
,
is_reshape
);
}
}
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass
,
typedef
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
1
>
);
priorbox_pattern
;
REGISTER_PASS
(
simplify_anakin_priorbox_detection_out_pass
,
priorbox_pattern
);
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass2
,
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
2
>
);
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass3
,
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
3
>
);
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass4
,
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
4
>
);
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass5
,
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
5
>
);
REGISTER_PASS
(
simplify_anakin_detection_pattern_pass6
,
paddle
::
framework
::
ir
::
SimplifyAnakinDetectionPatternPass
<
6
>
);
paddle/fluid/framework/ir/simplify_anakin_
detection_pattern
_pass.h
→
paddle/fluid/framework/ir/simplify_anakin_
priorbox_detection_out
_pass.h
浏览文件 @
3e6aa498
...
@@ -26,7 +26,6 @@ namespace ir {
...
@@ -26,7 +26,6 @@ namespace ir {
// these structures will be used as inputs to the concat Op. This pattern will
// these structures will be used as inputs to the concat Op. This pattern will
// be detected by our pass. The times here represents the repeat times of this
// be detected by our pass. The times here represents the repeat times of this
// structure.
// structure.
template
<
int
times
>
class
SimplifyAnakinDetectionPatternPass
:
public
FusePassBase
{
class
SimplifyAnakinDetectionPatternPass
:
public
FusePassBase
{
public:
public:
virtual
~
SimplifyAnakinDetectionPatternPass
()
{}
virtual
~
SimplifyAnakinDetectionPatternPass
()
{}
...
...
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
浏览文件 @
3e6aa498
...
@@ -25,11 +25,9 @@ namespace paddle {
...
@@ -25,11 +25,9 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
template
<
int
times
>
void
RunTransposeFlattenConcatFuse
(
ir
::
Graph
*
graph
,
int
times
)
{
void
TransposeFlattenConcatFusePass
<
times
>::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
const
std
::
string
pattern_name
=
"transpose_flatten"
+
std
::
to_string
(
times
)
+
"_concat_fuse"
;
"transpose_flatten"
+
std
::
to_string
(
times
)
+
"_concat_fuse"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
std
::
vector
<
PDNode
*>
input_nodes
;
std
::
vector
<
PDNode
*>
input_nodes
;
...
@@ -122,31 +120,18 @@ void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
...
@@ -122,31 +120,18 @@ void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
}
}
template
class
TransposeFlattenConcatFusePass
<
1
>;
void
TransposeFlattenConcatFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
template
class
TransposeFlattenConcatFusePass
<
2
>;
const
int
pattern_nums
=
6
;
template
class
TransposeFlattenConcatFusePass
<
3
>;
const
std
::
string
pattern_name
=
"transpose_flatten_concat_fuse"
;
template
class
TransposeFlattenConcatFusePass
<
4
>;
FusePassBase
::
Init
(
pattern_name
,
graph
);
template
class
TransposeFlattenConcatFusePass
<
5
>;
for
(
int
i
=
1
;
i
<=
pattern_nums
;
i
++
)
{
template
class
TransposeFlattenConcatFusePass
<
6
>;
RunTransposeFlattenConcatFuse
(
graph
,
i
);
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
transpose_flatten_concat_fuse_pass
,
REGISTER_PASS
(
transpose_flatten_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
1
>
);
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
);
REGISTER_PASS
(
transpose_flatten2_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
2
>
);
REGISTER_PASS
(
transpose_flatten3_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
3
>
);
REGISTER_PASS
(
transpose_flatten4_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
4
>
);
REGISTER_PASS
(
transpose_flatten5_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
5
>
);
REGISTER_PASS
(
transpose_flatten6_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
6
>
);
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
浏览文件 @
3e6aa498
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...
@@ -24,7 +26,6 @@ namespace ir {
...
@@ -24,7 +26,6 @@ namespace ir {
// these structures will be used as inputs to the concat Op. This pattern will
// these structures will be used as inputs to the concat Op. This pattern will
// be detected by our pass. The times here represents the repeat times of this
// be detected by our pass. The times here represents the repeat times of this
// structure.
// structure.
template
<
int
times
>
class
TransposeFlattenConcatFusePass
:
public
FusePassBase
{
class
TransposeFlattenConcatFusePass
:
public
FusePassBase
{
public:
public:
virtual
~
TransposeFlattenConcatFusePass
()
{}
virtual
~
TransposeFlattenConcatFusePass
()
{}
...
...
paddle/fluid/inference/anakin/convert/density_prior_box.cc
浏览文件 @
3e6aa498
...
@@ -34,25 +34,41 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
...
@@ -34,25 +34,41 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
auto
input_name
=
op_desc
.
Input
(
"Input"
).
front
();
auto
input_name
=
op_desc
.
Input
(
"Input"
).
front
();
auto
image_name
=
op_desc
.
Input
(
"Image"
).
front
();
auto
image_name
=
op_desc
.
Input
(
"Image"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Boxes"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Boxes"
).
front
();
auto
op_type
=
op_desc
.
Type
();
auto
op_name
=
op_type
+
":"
+
op_desc
.
Output
(
"Boxes"
).
front
();
auto
op_name
=
op_desc
.
Type
()
+
":"
+
op_desc
.
Output
(
"Boxes"
).
front
();
// only for density_prior_box
std
::
vector
<
float
>
fixed_sizes
=
{};
std
::
vector
<
float
>
fixed_ratios
=
{};
std
::
vector
<
int
>
densities
=
{};
auto
fixed_sizes
=
std
::
vector
<
float
>
min_sizes
=
{};
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"fixed_sizes"
));
std
::
vector
<
float
>
max_sizes
=
{};
auto
fixed_ratios
=
std
::
vector
<
float
>
aspect_ratios
=
{};
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"fixed_ratios"
));
bool
is_clip
=
false
;
auto
densities
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"densities"
));
bool
is_flip
=
false
;
if
(
op_type
==
"density_prior_box"
)
{
fixed_sizes
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"fixed_sizes"
));
fixed_ratios
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"fixed_ratios"
));
densities
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"densities"
));
is_clip
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"clip"
));
}
else
if
(
op_type
==
"prior_box"
)
{
min_sizes
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"min_sizes"
));
max_sizes
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"max_sizes"
));
aspect_ratios
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"aspect_ratios"
));
is_clip
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"clip"
));
is_flip
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"flip"
));
}
std
::
vector
<
float
>
dens
;
std
::
vector
<
float
>
dens
;
for
(
auto
&
ele
:
densities
)
{
for
(
auto
&
ele
:
densities
)
{
dens
.
push_back
(
static_cast
<
float
>
(
ele
));
dens
.
push_back
(
static_cast
<
float
>
(
ele
));
}
}
// lack flip
// auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
auto
variances
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"variances"
));
auto
variances
=
boost
::
get
<
std
::
vector
<
float
>>
(
op_desc
.
GetAttr
(
"variances"
));
for
(
auto
&
ele
:
variances
)
{
LOG
(
INFO
)
<<
ele
;
}
// lack img_h, img_w
// lack img_h, img_w
auto
step_h
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"step_h"
));
auto
step_h
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"step_h"
));
...
@@ -66,14 +82,14 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
...
@@ -66,14 +82,14 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
std
::
vector
<
float
>
temp_v
=
{};
std
::
vector
<
float
>
temp_v
=
{};
engine_
->
AddOp
(
op_name
,
"PriorBox"
,
{
input_name
,
image_name
},
{
output_name
});
engine_
->
AddOp
(
op_name
,
"PriorBox"
,
{
input_name
,
image_name
},
{
output_name
});
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"min_size"
,
temp_v
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"min_size"
,
min_sizes
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"max_size"
,
temp_v
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"max_size"
,
max_sizes
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"aspect_ratio"
,
temp_v
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"aspect_ratio"
,
aspect_ratios
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"fixed_size"
,
fixed_sizes
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"fixed_size"
,
fixed_sizes
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"fixed_ratio"
,
fixed_ratios
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"fixed_ratio"
,
fixed_ratios
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"density"
,
dens
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"density"
,
dens
);
engine_
->
AddOpAttr
(
op_name
,
"is_flip"
,
static_cast
<
bool
>
(
false
)
);
engine_
->
AddOpAttr
(
op_name
,
"is_flip"
,
is_flip
);
engine_
->
AddOpAttr
(
op_name
,
"is_clip"
,
static_cast
<
bool
>
(
false
)
);
engine_
->
AddOpAttr
(
op_name
,
"is_clip"
,
is_clip
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"variance"
,
variances
);
engine_
->
AddOpAttr
<
PTuple
<
float
>>
(
op_name
,
"variance"
,
variances
);
engine_
->
AddOpAttr
(
op_name
,
"img_h"
,
static_cast
<
int
>
(
0
));
engine_
->
AddOpAttr
(
op_name
,
"img_h"
,
static_cast
<
int
>
(
0
));
engine_
->
AddOpAttr
(
op_name
,
"img_w"
,
static_cast
<
int
>
(
0
));
engine_
->
AddOpAttr
(
op_name
,
"img_w"
,
static_cast
<
int
>
(
0
));
...
@@ -88,3 +104,4 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
...
@@ -88,3 +104,4 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
}
// namespace paddle
}
// namespace paddle
REGISTER_ANAKIN_OP_CONVERTER
(
density_prior_box
,
DensityPriorBoxOpConverter
);
REGISTER_ANAKIN_OP_CONVERTER
(
density_prior_box
,
DensityPriorBoxOpConverter
);
REGISTER_ANAKIN_OP_CONVERTER
(
prior_box
,
DensityPriorBoxOpConverter
);
paddle/fluid/inference/anakin/convert/op_converter.h
浏览文件 @
3e6aa498
...
@@ -48,7 +48,7 @@ class AnakinOpConverter {
...
@@ -48,7 +48,7 @@ class AnakinOpConverter {
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
std
::
string
op_type
=
op_desc
.
Type
();
std
::
string
op_type
=
op_desc
.
Type
();
AnakinOpConverter
*
it
=
nullptr
;
AnakinOpConverter
*
it
=
nullptr
;
if
(
op_type
==
"depthwise_conv2d"
)
op_type
=
"conv2d"
;
if
(
op_type
==
"reshape2"
)
op_type
=
"reshape"
;
if
(
op_type
==
"reshape2"
)
op_type
=
"reshape"
;
if
(
op_type
==
"transpose2"
)
op_type
=
"transpose"
;
if
(
op_type
==
"transpose2"
)
op_type
=
"transpose"
;
if
(
op_type
==
"flatten2"
)
op_type
=
"flatten"
;
if
(
op_type
==
"flatten2"
)
op_type
=
"flatten"
;
...
...
paddle/fluid/inference/anakin/op_teller.cc
浏览文件 @
3e6aa498
...
@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set
.
insert
(
"dropout"
);
teller_set
.
insert
(
"dropout"
);
teller_set
.
insert
(
"sigmoid"
);
teller_set
.
insert
(
"sigmoid"
);
teller_set
.
insert
(
"sum"
);
teller_set
.
insert
(
"sum"
);
teller_set
.
insert
(
"depthwise_conv2d"
);
teller_set
.
insert
(
"prior_box"
);
}
}
bool
operator
()(
const
std
::
string
&
op_type
,
bool
operator
()(
const
std
::
string
&
op_type
,
...
...
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
浏览文件 @
3e6aa498
...
@@ -37,14 +37,14 @@ using framework::ir::Node;
...
@@ -37,14 +37,14 @@ using framework::ir::Node;
void
analysis
::
AnakinSubgraphPass
::
ApplyImpl
(
void
analysis
::
AnakinSubgraphPass
::
ApplyImpl
(
framework
::
ir
::
Graph
*
graph
)
const
{
framework
::
ir
::
Graph
*
graph
)
const
{
framework
::
ir
::
FusePassBase
::
Init
(
"anakin_subgraph_pass"
,
graph
.
get
()
);
framework
::
ir
::
FusePassBase
::
Init
(
"anakin_subgraph_pass"
,
graph
);
auto
teller
=
[](
const
framework
::
ir
::
Node
*
node
)
{
auto
teller
=
[](
const
framework
::
ir
::
Node
*
node
)
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
())
return
false
;
if
(
!
node
->
IsOp
()
||
!
node
->
Op
())
return
false
;
return
anakin
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
return
anakin
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
};
};
SubGraphFuser
fuser
(
graph
.
get
()
,
teller
,
6
/* min_subgraph_size */
);
SubGraphFuser
fuser
(
graph
,
teller
,
6
/* min_subgraph_size */
);
fuser
();
fuser
();
std
::
vector
<
std
::
string
>
graph_param_names
=
std
::
vector
<
std
::
string
>
graph_param_names
=
...
@@ -56,10 +56,10 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
...
@@ -56,10 +56,10 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
!
Agent
(
node
).
subgraph
()
->
empty
())
{
if
(
node
->
IsOp
()
&&
!
Agent
(
node
).
subgraph
()
->
empty
())
{
CreateAnakinOp
(
node
,
graph
.
get
()
,
graph_param_names
,
&
repetitive_params
);
CreateAnakinOp
(
node
,
graph
,
graph_param_names
,
&
repetitive_params
);
std
::
unordered_set
<
const
Node
*>
nodes2remove
(
std
::
unordered_set
<
const
Node
*>
nodes2remove
(
Agent
(
node
).
subgraph
()
->
begin
(),
Agent
(
node
).
subgraph
()
->
end
());
Agent
(
node
).
subgraph
()
->
begin
(),
Agent
(
node
).
subgraph
()
->
end
());
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
.
get
()
,
nodes2remove
);
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
,
nodes2remove
);
}
}
}
}
...
@@ -69,7 +69,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
...
@@ -69,7 +69,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
nodes2remove
.
insert
(
node
);
nodes2remove
.
insert
(
node
);
}
}
}
}
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
.
get
()
,
nodes2remove
);
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
,
nodes2remove
);
graph
->
Set
(
framework
::
ir
::
kRepetitiveParamAttr
,
graph
->
Set
(
framework
::
ir
::
kRepetitiveParamAttr
,
new
std
::
vector
<
std
::
string
>
(
repetitive_params
));
new
std
::
vector
<
std
::
string
>
(
repetitive_params
));
}
}
...
...
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
3e6aa498
...
@@ -192,6 +192,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
...
@@ -192,6 +192,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
block_desc
.
Proto
()
->
SerializeAsString
());
block_desc
.
Proto
()
->
SerializeAsString
());
SetAttr
(
op_desc
->
Proto
(),
"max_batch_size"
,
Get
<
int
>
(
"max_batch_size"
));
SetAttr
(
op_desc
->
Proto
(),
"max_batch_size"
,
Get
<
int
>
(
"max_batch_size"
));
SetAttr
(
op_desc
->
Proto
(),
"workspace_size"
,
Get
<
int
>
(
"workspace_size"
));
SetAttr
(
op_desc
->
Proto
(),
"workspace_size"
,
Get
<
int
>
(
"workspace_size"
));
SetAttr
(
op_desc
->
Proto
(),
"gpu_id"
,
Get
<
int
>
(
"gpu_device_id"
));
SetAttr
(
op_desc
->
Proto
(),
"output_name_mapping"
,
output_mapping
);
SetAttr
(
op_desc
->
Proto
(),
"output_name_mapping"
,
output_mapping
);
SetAttr
(
op_desc
->
Proto
(),
"parameters"
,
params
);
SetAttr
(
op_desc
->
Proto
(),
"parameters"
,
params
);
...
...
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
浏览文件 @
3e6aa498
...
@@ -52,6 +52,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
...
@@ -52,6 +52,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
for
(
auto
&
var_name
:
all_vars
)
{
for
(
auto
&
var_name
:
all_vars
)
{
if
(
std
::
count
(
repetitive_params
.
begin
(),
repetitive_params
.
end
(),
if
(
std
::
count
(
repetitive_params
.
begin
(),
repetitive_params
.
end
(),
var_name
))
{
var_name
))
{
scope
->
EraseVars
({
var_name
});
continue
;
continue
;
}
}
auto
*
var
=
scope
->
FindLocalVar
(
var_name
);
auto
*
var
=
scope
->
FindLocalVar
(
var_name
);
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
3e6aa498
...
@@ -886,4 +886,5 @@ USE_ANAKIN_CONVERTER(detection_out);
...
@@ -886,4 +886,5 @@ USE_ANAKIN_CONVERTER(detection_out);
USE_ANAKIN_CONVERTER
(
density_prior_box
);
USE_ANAKIN_CONVERTER
(
density_prior_box
);
USE_ANAKIN_CONVERTER
(
dropout
);
USE_ANAKIN_CONVERTER
(
dropout
);
USE_ANAKIN_CONVERTER
(
sum
);
USE_ANAKIN_CONVERTER
(
sum
);
USE_ANAKIN_CONVERTER
(
prior_box
);
#endif
#endif
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
3e6aa498
...
@@ -70,17 +70,15 @@ void GpuPassStrategy::EnableMKLDNN() {
...
@@ -70,17 +70,15 @@ void GpuPassStrategy::EnableMKLDNN() {
// The following passes works for Anakin sub-graph engine.
// The following passes works for Anakin sub-graph engine.
const
std
::
vector
<
std
::
string
>
kAnakinSubgraphPasses
({
const
std
::
vector
<
std
::
string
>
kAnakinSubgraphPasses
({
"infer_clean_graph_pass"
,
//
"infer_clean_graph_pass"
,
//
"simplify_anakin_detection_pattern_pass5"
,
//
"simplify_anakin_priorbox_detection_out_pass"
,
//
"simplify_anakin_detection_pattern_pass4"
,
//
"fillconstant_elementwisemul_fuse"
,
//
"simplify_anakin_detection_pattern_pass3"
,
//
"fc_fuse_pass"
,
//
"simplify_anakin_detection_pattern_pass2"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"anakin_fillconstant_elementwisemul_fuse"
,
//
"conv_bn_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"quant_conv2d_dequant_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"anakin_subgraph_pass"
,
"anakin_subgraph_pass"
,
});
});
...
@@ -97,13 +95,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
...
@@ -97,13 +95,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add2_act_fuse_pass"
,
//
"conv_elementwise_add2_act_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"runtime_context_cache_pass"
,
//
"runtime_context_cache_pass"
,
//
#endif
#endif //
"transpose_flatten_concat_fuse_pass"
,
});
});
for
(
int
i
=
6
;
i
>=
2
;
i
--
)
{
passes_
.
push_back
(
"transpose_flatten"
+
std
::
to_string
(
i
)
+
"_concat_fuse_pass"
);
}
use_gpu_
=
true
;
use_gpu_
=
true
;
}
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
3e6aa498
...
@@ -52,6 +52,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -52,6 +52,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
std
::
string
engine_key_
;
std
::
string
engine_key_
;
std
::
string
engine_serialized_data_
;
std
::
string
engine_serialized_data_
;
bool
calibration_mode_
;
bool
calibration_mode_
;
int
device_id_
;
public:
public:
TensorRTEngineOp
(
const
std
::
string
&
type
,
TensorRTEngineOp
(
const
std
::
string
&
type
,
...
@@ -62,6 +63,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -62,6 +63,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_
=
Inputs
(
"Xs"
);
input_names_
=
Inputs
(
"Xs"
);
max_batch_size_
=
Attr
<
int
>
(
"max_batch_size"
);
max_batch_size_
=
Attr
<
int
>
(
"max_batch_size"
);
workspace_size_
=
Attr
<
int
>
(
"workspace_size"
);
workspace_size_
=
Attr
<
int
>
(
"workspace_size"
);
device_id_
=
Attr
<
int
>
(
"gpu_id"
);
enable_int8_
=
Attr
<
bool
>
(
"enable_int8"
);
enable_int8_
=
Attr
<
bool
>
(
"enable_int8"
);
calibration_data_
=
Attr
<
std
::
string
>
(
"calibration_data"
);
calibration_data_
=
Attr
<
std
::
string
>
(
"calibration_data"
);
engine_key_
=
Attr
<
std
::
string
>
(
"engine_key"
);
engine_key_
=
Attr
<
std
::
string
>
(
"engine_key"
);
...
@@ -79,6 +81,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -79,6 +81,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
if
(
enable_int8_
&&
calibration_data_
.
size
())
{
if
(
enable_int8_
&&
calibration_data_
.
size
())
{
calibrator_
.
reset
(
new
TRTInt8Calibrator
(
calibration_data_
));
calibrator_
.
reset
(
new
TRTInt8Calibrator
(
calibration_data_
));
}
}
if
(
!
calibration_mode_
&&
!
engine_serialized_data_
.
empty
())
{
trt_engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
enable_int8_
,
calibrator_
.
get
(),
device_id_
));
PADDLE_ENFORCE
(
engine_serialized_data_
.
size
(),
"TRT serialized data should not be empty here,"
"there must be error when generate serialized data in TRT "
"subgraph detect pass."
);
trt_engine_
->
Deserialize
(
engine_serialized_data_
);
}
}
}
protected:
protected:
...
@@ -225,12 +238,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -225,12 +238,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
if
(
!
trt_engine_
)
{
if
(
!
trt_engine_
)
{
trt_engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
trt_engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
enable_int8_
,
calibrator_
.
get
(),
max_batch_size_
,
workspace_size_
,
enable_int8_
,
calibrator_
.
get
(),
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
));
device_id_
));
if
(
!
engine_serialized_data_
.
empty
())
{
PrepareTRTEngine
(
scope
,
trt_engine_
.
get
());
trt_engine_
->
Deserialize
(
engine_serialized_data_
);
}
else
{
PrepareTRTEngine
(
scope
,
trt_engine_
.
get
());
}
}
}
return
trt_engine_
.
get
();
return
trt_engine_
.
get
();
}
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
浏览文件 @
3e6aa498
...
@@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) {
...
@@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) {
std
::
vector
<
std
::
string
>
({
"z0"
}));
std
::
vector
<
std
::
string
>
({
"z0"
}));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
engine_op_desc
.
SetAttr
(
"engine_serialized_data"
,
std
::
string
(
""
));
engine_op_desc
.
SetAttr
(
"engine_serialized_data"
,
std
::
string
(
""
));
int
device_id
=
0
;
engine_op_desc
.
SetAttr
(
"gpu_id"
,
device_id
);
LOG
(
INFO
)
<<
"create engine op"
;
LOG
(
INFO
)
<<
"create engine op"
;
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
...
@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
...
@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
std
::
vector
<
std
::
string
>
({
"z3"
}));
std
::
vector
<
std
::
string
>
({
"z3"
}));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
engine_op_desc
.
SetAttr
(
"engine_serialized_data"
,
std
::
string
(
""
));
engine_op_desc
.
SetAttr
(
"engine_serialized_data"
,
std
::
string
(
""
));
int
device_id
=
0
;
engine_op_desc
.
SetAttr
(
"gpu_id"
,
device_id
);
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录