Commit cb71fea0 (unverified)

add inference api: DisableTensorRtOps (#30109) (#30178)

Authored by Shang Zhizhou on Jan 07, 2021; committed via GitHub on Jan 07, 2021.
Parent commit: a2b0357d
7 changed files, with 32 additions and 2 deletions (+32, -2):
paddle/fluid/inference/analysis/argument.h                            +2  -0
paddle/fluid/inference/analysis/ir_pass_manager.cc                    +2  -0
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc   +7  -0
paddle/fluid/inference/api/analysis_config.cc                         +9  -0
paddle/fluid/inference/api/analysis_predictor.cc                      +1  -0
paddle/fluid/inference/api/paddle_analysis_config.h                   +10 -2
paddle/fluid/inference/tests/api/trt_mobilenet_test.cc                +1  -0
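
Before the per-file diffs, a minimal sketch of how the new API is intended to be called. The pattern mirrors the test updated in this commit; the model directory is a placeholder, not taken from the source:

    // Minimal usage sketch, assuming the C++ inference API of this era.
    paddle::AnalysisConfig config;
    config.EnableUseGpu(100, 0);            // initial GPU memory pool (MB), device id
    config.SetModel("/path/to/model_dir");  // placeholder path
    config.EnableTensorRtEngine();
    // New in this commit: keep the listed op types out of Paddle-TRT subgraphs.
    config.Exp_DisableTensorRtOPs({"fc"});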
paddle/fluid/inference/analysis/argument.h

@@ -202,6 +202,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
+  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs,
+                      std::vector<std::string>);
   DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                       AnalysisConfig::Precision);
   DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
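
For context, DECL_ARGUMENT_FIELD is the macro Argument uses to declare a typed field together with its getter and Set* accessor. A minimal sketch of what such a macro might expand to, assuming a simplified layout (the real macro in argument.h guards validity and storage differently, so the expansion below is illustrative only):

    // Hypothetical simplification of DECL_ARGUMENT_FIELD; not the real expansion.
    #define DECL_ARGUMENT_FIELD(field__, Field, type__)            \
     private:                                                      \
      type__ field__##_;                                           \
      bool field__##_valid_{false};                                \
                                                                   \
     public:                                                       \
      type__ &field__() { return field__##_; }                     \
      void Set##Field(const type__ &x) {                           \
        field__##_ = x;                                            \
        field__##_valid_ = true;                                   \
      }                                                            \
      bool field__##_valid() const { return field__##_valid_; }

With the new line, Argument gains tensorrt_disabled_ops() and SetTensorRtDisabledOPs(), which ir_pass_manager.cc and analysis_predictor.cc call in the diffs below.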
paddle/fluid/inference/analysis/ir_pass_manager.cc

@@ -141,6 +141,8 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("optim_input_shape",
                 new std::map<std::string, std::vector<int>>(
                     argument->optim_input_shape()));
+      pass->Set("trt_disabled_ops", new std::vector<std::string>(
+                                        argument->tensorrt_disabled_ops()));
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not
       // run fp16.
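
Here the pass manager copies the disabled-op list out of Argument and attaches it to the TensorRT subgraph pass as a named attribute. A self-contained sketch of that Set/Get attribute handoff, with ownership simplified relative to Paddle's framework::ir::Pass (PassSketch is a stand-in name, not a Paddle type):

    #include <map>
    #include <memory>
    #include <string>

    class PassSketch {
     public:
      // Set() takes ownership of a heap-allocated value stored under a key.
      template <typename T>
      void Set(const std::string &key, T *value) {
        attrs_[key] = std::shared_ptr<void>(value);
      }
      // Get<T>() later retrieves the attribute by reference, as
      // tensorrt_subgraph_pass.cc does with "trt_disabled_ops".
      template <typename T>
      T &Get(const std::string &key) const {
        return *static_cast<T *>(attrs_.at(key).get());
      }

     private:
      std::map<std::string, std::shared_ptr<void>> attrs_;
    };

shared_ptr<void> keeps the storage type-erased while still destroying each attribute with the correct deleter.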
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

@@ -39,8 +39,15 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
   bool no_calib_int8 = enable_int8 && !(use_calib_mode);
+  auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
   auto teller = [&](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
+    if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
+             node->Op()->Type()) != trt_disabled_ops.end()) {
+      VLOG(3) << node->Op()->Type().c_str()
+              << " is diabled by config in TensorRT";
+      return false;
+    }
     return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
                                              no_calib_int8);
   };
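
The teller is the predicate the subgraph pass asks for every graph node; the new branch rejects any op whose type appears in the user's deny-list before OpTeller's capability check runs. A runnable, self-contained sketch of just that filtering logic (TellOp is a hypothetical stand-in for the lambda; the real code consults tensorrt::OpTeller afterwards):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for the teller lambda: deny-listed op types never enter TRT.
    bool TellOp(const std::string &op_type,
                const std::vector<std::string> &trt_disabled_ops) {
      if (std::find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
                    op_type) != trt_disabled_ops.end()) {
        std::cout << op_type << " is disabled by config in TensorRT\n";
        return false;
      }
      return true;  // placeholder for tensorrt::OpTeller::Global().Tell(...)
    }

    int main() {
      const std::vector<std::string> disabled = {"fc"};
      std::cout << std::boolalpha;
      std::cout << TellOp("fc", disabled) << "\n";      // false: deny-listed
      std::cout << TellOp("conv2d", disabled) << "\n";  // true: falls through
    }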
paddle/fluid/inference/api/analysis_config.cc

@@ -125,6 +125,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
+  CP_MEMBER(trt_disabled_ops_);
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
   CP_MEMBER(trt_use_oss_);

@@ -304,6 +305,11 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
   disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
 }

+void AnalysisConfig::Exp_DisableTensorRtOPs(
+    const std::vector<std::string> &ops) {
+  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+}
+
 void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }

 // TODO(Superjomn) refactor this, buggy.

@@ -443,6 +449,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_max_batchsize_;
   ss << tensorrt_min_subgraph_size_;

+  for (auto &op : trt_disabled_ops_) ss << op.c_str();
+  ss << ";";
+
   ss << enable_memory_optim_;
   ss << use_mkldnn_;
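
Three things change in this file: the copy constructor now copies the new member, Exp_DisableTensorRtOPs appends to trt_disabled_ops_ rather than replacing it, and SerializeInfoCache folds the list into the config fingerprint, so changing the deny-list changes anything keyed on that cache string. A consequence of the append semantics (a usage note, not code from the commit):

    config.Exp_DisableTensorRtOPs({"fc"});
    config.Exp_DisableTensorRtOPs({"elementwise_add"});
    // trt_disabled_ops_ now holds {"fc", "elementwise_add"}: calls accumulate.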
paddle/fluid/inference/api/analysis_predictor.cc

@@ -476,6 +476,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
+    argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
     argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
     argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
     argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
paddle/fluid/inference/api/paddle_analysis_config.h

@@ -313,10 +313,17 @@ struct PD_INFER_DECL AnalysisConfig {
       std::map<std::string, std::vector<int>> optim_input_shape,
       bool disable_trt_plugin_fp16 = false);

+  ///
+  /// \brief Prevent ops running in Paddle-TRT
+  /// NOTE: just experimental, not an official stable API, easy to be broken.
+  ///
+  void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);
+
   ///
-  /// \brief Replace some TensorRT plugins to TensorRT OSS(
-  /// https://github.com/NVIDIA/TensorRT), with which some models's inference may
-  /// be more high-performance. Libnvinfer_plugin.so greater than V7.2.1 is needed.
+  /// \brief Replace some TensorRT plugins to TensorRT OSS(
+  /// https://github.com/NVIDIA/TensorRT), with which some models's inference
+  /// may be more high-performance. Libnvinfer_plugin.so greater than
+  /// V7.2.1 is needed.
   ///
   void EnableTensorRtOSS();

@@ -587,6 +594,7 @@ struct PD_INFER_DECL AnalysisConfig {
   std::map<std::string, std::vector<int>> min_input_shape_{};
   std::map<std::string, std::vector<int>> max_input_shape_{};
   std::map<std::string, std::vector<int>> optim_input_shape_{};
+  std::vector<std::string> trt_disabled_ops_{};
   bool disable_trt_plugin_fp16_{false};
   // memory reuse related.
paddle/fluid/inference/tests/api/trt_mobilenet_test.cc

@@ -57,6 +57,7 @@ TEST(PredictorPool, use_gpu) {
   config.EnableUseGpu(100, 0);
   config.SetModel(model_dir);
   config.EnableTensorRtEngine();
+  config.Exp_DisableTensorRtOPs({"fc"});
   services::PredictorPool pred_pool(config, 1);
   auto predictor = pred_pool.Retrive(0);
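
The updated test exercises the new API end to end by excluding "fc" from TensorRT before building the predictor pool. To observe the exclusion taking effect, one would presumably enable verbose glog output (e.g. GLOG_v=3 or higher, an assumption about the logging setup rather than something this commit documents) and look for the VLOG(3) message emitted by tensorrt_subgraph_pass.cc for each op kept out of TensorRT.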