Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
97b76c94
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
97b76c94
编写于
1月 28, 2019
作者:
Z
Zhaolong Xing
提交者:
GitHub
1月 28, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15242 from NHZlX/trt_int8_ultimate_version
add trt int8 support
上级
10bc9ffc
b43ea40c
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
674 addition
and
59 deletion
+674
-59
paddle/fluid/framework/ir/graph_traits.cc
paddle/fluid/framework/ir/graph_traits.cc
+2
-1
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+3
-0
paddle/fluid/inference/analysis/helper.cc
paddle/fluid/inference/analysis/helper.cc
+8
-0
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+54
-0
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+19
-2
paddle/fluid/inference/analysis/ir_pass_manager.h
paddle/fluid/inference/analysis/ir_pass_manager.h
+3
-2
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+46
-7
paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
...uid/inference/analysis/passes/ir_graph_to_program_pass.cc
+5
-1
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+5
-3
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+73
-0
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+15
-0
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+7
-1
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+7
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+16
-8
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
+147
-0
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
+128
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
+7
-1
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+93
-4
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
+25
-25
paddle/fluid/pybind/inference_api.cc
paddle/fluid/pybind/inference_api.cc
+10
-3
未找到文件。
paddle/fluid/framework/ir/graph_traits.cc
浏览文件 @
97b76c94
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include <set>
#include <vector>
namespace
paddle
{
...
...
@@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
}
std
::
unordered_set
<
Node
*>
visited
;
std
::
unordered_
set
<
Node
*>
to_visit
{
source
.
begin
(),
source
.
end
()};
std
::
set
<
Node
*>
to_visit
{
source
.
begin
(),
source
.
end
()};
std
::
vector
<
Node
*>
inlink_visited
;
while
(
!
to_visit
.
empty
())
{
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
97b76c94
...
...
@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
...
...
@@ -130,6 +131,8 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
tensorrt_max_batch_size
,
TensorRtMaxBatchSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_workspace_size
,
TensorRtWorkspaceSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_min_subgraph_size
,
TensorRtMinSubgraphSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_precision_mode
,
TensorRtPrecisionMode
,
contrib
::
AnalysisConfig
::
Precision
);
// Memory optimized related.
DECL_ARGUMENT_FIELD
(
enable_memory_optim
,
EnableMemoryOptim
,
bool
);
...
...
paddle/fluid/inference/analysis/helper.cc
浏览文件 @
97b76c94
...
...
@@ -36,6 +36,14 @@ void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
attr
->
set_i
(
data
);
}
template
<
>
void
SetAttr
<
bool
>
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
bool
&
data
)
{
auto
*
attr
=
op
->
add_attrs
();
attr
->
set_name
(
name
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
BOOLEAN
);
attr
->
set_b
(
data
);
}
template
<
>
void
SetAttr
<
int64_t
>
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
int64_t
&
data
)
{
auto
*
attr
=
op
->
add_attrs
();
...
...
paddle/fluid/inference/analysis/helper.h
浏览文件 @
97b76c94
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <sys/stat.h>
#include <cstdio>
#include <fstream>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
...
...
@@ -29,9 +30,14 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define GCC_ATTRIBUTE(attr__) ;
#define MKDIR(path) _mkdir(path)
#else
#include <unistd.h>
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
...
...
@@ -163,6 +169,54 @@ static bool PathExists(const std::string &path) {
return
false
;
}
static
std
::
string
GetDirRoot
(
const
std
::
string
&
path
)
{
char
sep
=
'/'
;
#ifdef _WIN32
sep
=
'\\'
;
#endif
size_t
i
=
path
.
rfind
(
sep
,
path
.
length
());
if
(
i
!=
std
::
string
::
npos
)
{
return
(
path
.
substr
(
0
,
i
));
}
return
path
;
}
static
std
::
string
GetOrCreateModelOptCacheDir
(
const
std
::
string
&
model_root
)
{
std
::
string
opt_cache_dir
=
model_root
+
"/_opt_cache/"
;
if
(
!
PathExists
(
opt_cache_dir
))
{
PADDLE_ENFORCE
(
MKDIR
(
opt_cache_dir
.
c_str
())
!=
-
1
,
"Can not create optimize cache directory: %s, Make sure you "
"have permission to write"
,
opt_cache_dir
);
}
return
opt_cache_dir
;
}
static
std
::
string
GetTrtCalibPath
(
const
std
::
string
&
model_root
,
const
std
::
string
&
engine_key
)
{
return
model_root
+
"/trt_calib_"
+
engine_key
;
}
// If there is no calib table data file in model_opt_cache_dir, return "".
static
std
::
string
GetTrtCalibTableData
(
const
std
::
string
&
model_opt_cache_dir
,
const
std
::
string
&
engine_key
,
bool
enable_int8
)
{
std
::
string
trt_calib_table_path
=
GetTrtCalibPath
(
model_opt_cache_dir
,
engine_key
);
if
(
enable_int8
&&
FileExists
(
trt_calib_table_path
))
{
VLOG
(
3
)
<<
"Calibration table file: "
<<
trt_calib_table_path
<<
"is found here"
;
std
::
ifstream
infile
(
trt_calib_table_path
,
std
::
ios
::
in
);
std
::
stringstream
buffer
;
buffer
<<
infile
.
rdbuf
();
std
::
string
calibration_data
(
buffer
.
str
());
return
calibration_data
;
}
return
""
;
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
97b76c94
...
...
@@ -67,6 +67,20 @@ void IRPassManager::CreatePasses(Argument *argument,
pass
->
Set
(
"max_batch_size"
,
new
int
(
argument
->
tensorrt_max_batch_size
()));
pass
->
Set
(
"min_subgraph_size"
,
new
int
(
argument
->
tensorrt_min_subgraph_size
()));
pass
->
Set
(
"program"
,
new
framework
::
ProgramDesc
*
(
&
argument
->
main_program
()));
bool
enable_int8
=
argument
->
tensorrt_precision_mode
()
==
contrib
::
AnalysisConfig
::
Precision
::
kInt8
;
pass
->
Set
(
"enable_int8"
,
new
bool
(
enable_int8
));
std
::
string
model_opt_cache_dir
=
argument
->
Has
(
"model_dir"
)
?
argument
->
model_dir
()
:
GetDirRoot
(
argument
->
model_program_path
());
pass
->
Set
(
"model_opt_cache_dir"
,
new
std
::
string
(
GetOrCreateModelOptCacheDir
(
model_opt_cache_dir
)));
}
// graph_ = pass->Apply(std::move(graph_));
...
...
@@ -91,11 +105,14 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
}
framework
::
proto
::
ProgramDesc
IRPassManager
::
AcquireProgram
(
std
::
unique_ptr
<
Graph
>
*
graph
,
const
ProgramDesc
&
program
)
const
{
std
::
unique_ptr
<
Graph
>
*
graph
,
ProgramDesc
*
program
)
const
{
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
ProgramDesc
desc
(
program
);
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc
desc
;
desc
.
CopyFrom
(
*
program
->
Proto
());
pass
->
SetNotOwned
(
"program"
,
&
desc
);
auto
*
the_graph
=
graph
->
release
();
*
graph
=
pass
->
Apply
(
std
::
unique_ptr
<
Graph
>
(
the_graph
));
...
...
paddle/fluid/inference/analysis/ir_pass_manager.h
浏览文件 @
97b76c94
...
...
@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -42,8 +43,8 @@ class IRPassManager final {
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
);
framework
::
proto
::
ProgramDesc
AcquireProgram
(
std
::
unique_ptr
<
Graph
>
*
graph
,
const
ProgramDesc
&
program
)
const
;
framework
::
proto
::
ProgramDesc
AcquireProgram
(
std
::
unique_ptr
<
Graph
>
*
graph
,
ProgramDesc
*
program
)
const
;
framework
::
ir
::
Graph
&
graph
()
const
{
return
*
graph_
;
}
...
...
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
97b76c94
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <set>
#include <string>
#include <vector>
...
...
@@ -67,12 +68,33 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
return
graph
;
}
std
::
string
GenerateEngineKey
(
const
std
::
set
<
std
::
string
>
&
engine_inputs
,
const
std
::
set
<
std
::
string
>
&
engine_outputs
)
{
std
::
string
engine_hash_key
=
""
;
for
(
auto
name
:
engine_inputs
)
{
engine_hash_key
+=
name
;
}
for
(
auto
name
:
engine_outputs
)
{
engine_hash_key
+=
name
;
}
auto
engine_key
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
engine_hash_key
));
return
engine_key
;
}
void
TensorRtSubgraphPass
::
CreateTensorRTOp
(
framework
::
ir
::
Node
*
node
,
Graph
*
graph
)
const
{
auto
*
op_desc
=
node
->
Op
();
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
PADDLE_ENFORCE
(
!
subgraph
.
empty
());
framework
::
ProgramDesc
*
program_desc
=
Get
<
framework
::
ProgramDesc
*>
(
"program"
);
// Add new block for TensorRTEngineOP
const
framework
::
BlockDesc
&
main_block
=
program_desc
->
Block
(
framework
::
kRootBlockIndex
);
// const framework::BlockDesc& main_block = program_desc->Block(0);
framework
::
BlockDesc
*
new_block
=
program_desc
->
AppendBlock
(
main_block
);
// An fake block desc.
framework
::
proto
::
BlockDesc
block_proto
;
framework
::
BlockDesc
block_desc
(
nullptr
,
&
block_proto
);
...
...
@@ -82,13 +104,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
subgraph
.
size
());
for
(
auto
*
node
:
subgraph
)
{
auto
*
new_block_op
=
new_block
->
AppendOp
();
auto
*
op
=
block_desc
.
AppendOp
();
*
new_block_op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
*
op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
}
// collect inputs
std
::
unordered_set
<
std
::
string
>
input_names
;
std
::
unordered_set
<
std
::
string
>
input_names_with_id
;
// Then, we will use the input_names_with_id and output_names_with_id to
// generate the eigine key.
// So, We use set instead of unordered_set here to ensure that the engine key
// is unique.
std
::
set
<
std
::
string
>
input_names
;
std
::
set
<
std
::
string
>
input_names_with_id
;
for
(
auto
*
x
:
node
->
inputs
)
{
input_names
.
insert
(
x
->
Name
());
input_names_with_id
.
insert
(
x
->
Name
()
+
std
::
to_string
(
x
->
id
()));
...
...
@@ -96,8 +123,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
op_desc
->
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
std
::
unordered_
set
<
std
::
string
>
output_names
;
std
::
unordered_
set
<
std
::
string
>
output_names_with_id
;
std
::
set
<
std
::
string
>
output_names
;
std
::
set
<
std
::
string
>
output_names_with_id
;
for
(
auto
*
x
:
node
->
outputs
)
{
output_names
.
insert
(
x
->
Name
());
output_names_with_id
.
insert
(
x
->
Name
()
+
std
::
to_string
(
x
->
id
()));
...
...
@@ -182,7 +209,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// to Tensor.
std
::
vector
<
std
::
string
>
output_mapping
;
for
(
auto
name
:
output_names
)
{
// LOG(INFO) << name << " " << output_name_map.size();
PADDLE_ENFORCE
(
output_name_map
.
count
(
name
)
!=
0
);
output_mapping
.
push_back
(
output_name_map
[
name
]);
}
...
...
@@ -193,16 +219,29 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*
vars
->
Add
()
=
*
node
->
Var
()
->
Proto
();
}
}
PADDLE_ENFORCE
(
!
block_desc
.
Proto
()
->
vars
().
empty
(),
"the block has no var-desc"
);
PADDLE_ENFORCE
(
!
output_mapping
.
empty
());
// Set attrs
op_desc
->
SetBlockAttr
(
"sub_block"
,
new_block
);
SetAttr
(
op_desc
->
Proto
(),
"subgraph"
,
block_desc
.
Proto
()
->
SerializeAsString
());
// Set attrs
SetAttr
(
op_desc
->
Proto
(),
"max_batch_size"
,
Get
<
int
>
(
"max_batch_size"
));
SetAttr
(
op_desc
->
Proto
(),
"workspace_size"
,
Get
<
int
>
(
"workspace_size"
));
SetAttr
(
op_desc
->
Proto
(),
"parameters"
,
ExtractParameters
(
graph
->
Nodes
()));
SetAttr
(
op_desc
->
Proto
(),
"output_name_mapping"
,
output_mapping
);
auto
enable_int8
=
Get
<
bool
>
(
"enable_int8"
);
auto
engine_key
=
GenerateEngineKey
(
input_names_with_id
,
output_names_with_id
);
std
::
string
calibration_data
=
GetTrtCalibTableData
(
Get
<
std
::
string
>
(
"model_opt_cache_dir"
),
engine_key
,
enable_int8
);
SetAttr
(
op_desc
->
Proto
(),
"calibration_data"
,
calibration_data
);
SetAttr
(
op_desc
->
Proto
(),
"enable_int8"
,
enable_int8
);
SetAttr
(
op_desc
->
Proto
(),
"engine_key"
,
engine_key
);
}
std
::
vector
<
std
::
string
>
ExtractParameters
(
...
...
paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
浏览文件 @
97b76c94
...
...
@@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
}
std
::
unique_ptr
<
Graph
>
graph
(
argument
->
main_graph_ptr
());
framework
::
ProgramDesc
desc
(
argument
->
main_program
());
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
framework
::
ProgramDesc
desc
;
desc
.
CopyFrom
(
*
argument
->
main_program
().
Proto
());
pass
->
SetNotOwned
(
"program"
,
&
desc
);
auto
thegraph
=
pass
->
Apply
(
std
::
move
(
graph
));
thegraph
.
release
();
// the argument still own the graph.
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
97b76c94
...
...
@@ -102,6 +102,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER
(
tensorrt_workspace_size_
);
CP_MEMBER
(
tensorrt_max_batchsize_
);
CP_MEMBER
(
tensorrt_min_subgraph_size_
);
CP_MEMBER
(
tensorrt_precision_mode_
);
// MKLDNN releated.
CP_MEMBER
(
use_mkldnn_
);
CP_MEMBER
(
mkldnn_enabled_op_types_
);
...
...
@@ -141,9 +142,9 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
Update
();
}
void
contrib
::
AnalysisConfig
::
EnableTensorRtEngine
(
int
workspace_size
,
int
max_batc
h_size
,
int
min_subgraph_siz
e
)
{
void
contrib
::
AnalysisConfig
::
EnableTensorRtEngine
(
int
workspace_size
,
int
max_batch_size
,
int
min_subgrap
h_size
,
contrib
::
AnalysisConfig
::
Precision
precision_mod
e
)
{
#ifdef PADDLE_WITH_CUDA
if
(
!
use_gpu
())
{
LOG
(
ERROR
)
<<
"To use TensorRT engine, please call EnableGpu() first"
;
...
...
@@ -154,6 +155,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
tensorrt_min_subgraph_size_
=
min_subgraph_size
;
tensorrt_precision_mode_
=
precision_mode
;
Update
();
#else
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
97b76c94
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
...
...
@@ -25,6 +26,7 @@
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
...
...
@@ -37,6 +39,8 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
DECLARE_bool
(
profile
);
...
...
@@ -44,6 +48,12 @@ DECLARE_bool(profile);
namespace
paddle
{
using
contrib
::
AnalysisConfig
;
using
inference
::
Singleton
;
#if PADDLE_WITH_TENSORRT
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
using
inference
::
tensorrt
::
TRTCalibratorEngine
;
using
inference
::
tensorrt
::
TRTCalibratorEngineManager
;
#endif
namespace
{
bool
IsPersistable
(
const
framework
::
VarDesc
*
var
)
{
...
...
@@ -339,6 +349,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
!
config_
.
params_file
().
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
().
empty
());
std
::
string
dir
=
inference
::
analysis
::
GetDirRoot
(
config_
.
prog_file
());
argument_
.
SetModelProgramPath
(
config_
.
prog_file
());
argument_
.
SetModelParamsPath
(
config_
.
params_file
());
}
...
...
@@ -349,6 +361,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_
.
SetTensorRtWorkspaceSize
(
config_
.
tensorrt_workspace_size_
);
argument_
.
SetTensorRtMaxBatchSize
(
config_
.
tensorrt_max_batchsize_
);
argument_
.
SetTensorRtMinSubgraphSize
(
config_
.
tensorrt_min_subgraph_size_
);
argument_
.
SetTensorRtPrecisionMode
(
config_
.
tensorrt_precision_mode_
);
}
if
(
config_
.
use_mkldnn_
)
{
...
...
@@ -569,7 +582,67 @@ bool AnalysisPredictor::LoadParameters() {
return
true
;
}
#if PADDLE_WITH_TENSORRT
bool
AnalysisPredictor
::
SaveTrtCalibToDisk
()
{
PADDLE_ENFORCE
(
config_
.
tensorrt_engine_enabled
(),
"This func can be invoked only in trt mode"
);
auto
&
block
=
inference_program_
->
Block
(
0
);
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
if
(
op_desc
->
Type
()
==
"tensorrt_engine"
)
{
std
::
string
engine_name
=
boost
::
get
<
std
::
string
>
(
op_desc
->
GetAttr
(
"engine_key"
));
if
(
!
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Has
(
engine_name
))
{
LOG
(
ERROR
)
<<
"You should run the predictor(with trt) on the real data "
"to generate calibration info"
;
return
false
;
}
TRTCalibratorEngine
*
calib_engine
=
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Get
(
engine_name
);
LOG
(
INFO
)
<<
"Wait for calib threads done."
;
calib_engine
->
calib_
->
waitAndSetDone
();
LOG
(
INFO
)
<<
"Generating TRT Calibration table data, this may cost a lot "
"of time..."
;
calib_engine
->
thr_
->
join
();
std
::
string
calibration_table_data
=
calib_engine
->
calib_
->
getCalibrationTableAsString
();
if
(
calibration_table_data
.
empty
())
{
LOG
(
ERROR
)
<<
"the calibration table is empty."
;
return
false
;
}
std
::
string
model_opt_cache_dir
=
argument_
.
Has
(
"model_dir"
)
?
argument_
.
model_dir
()
:
inference
::
analysis
::
GetDirRoot
(
argument_
.
model_program_path
());
std
::
string
calibration_table_data_path
=
inference
::
analysis
::
GetTrtCalibPath
(
inference
::
analysis
::
GetOrCreateModelOptCacheDir
(
model_opt_cache_dir
),
engine_name
);
std
::
ofstream
ofile
(
calibration_table_data_path
,
std
::
ios
::
out
);
LOG
(
INFO
)
<<
"Write Paddle-TRT INT8 calibration table data to file "
<<
calibration_table_data_path
;
ofile
<<
calibration_table_data
;
ofile
.
close
();
}
}
// Free all calibrator resources.
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
DeleteALL
();
return
true
;
}
#endif
AnalysisPredictor
::~
AnalysisPredictor
()
{
#if PADDLE_WITH_TENSORRT
if
(
config_
.
tensorrt_engine_enabled
()
&&
config_
.
tensorrt_precision_mode_
==
AnalysisConfig
::
Precision
::
kInt8
&&
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Has
())
{
SaveTrtCalibToDisk
();
}
#endif
if
(
FLAGS_profile
)
{
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kTotal
,
"./profile.log"
);
...
...
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
97b76c94
...
...
@@ -97,6 +97,21 @@ class AnalysisPredictor : public PaddlePredictor {
void
GetFetchOne
(
const
framework
::
LoDTensor
&
fetchs
,
PaddleTensor
*
output_data
);
#if PADDLE_WITH_TENSORRT
// When we use Paddle-TRT INT8 engine, we need to generate calibration table
// data first,
// the calibration table contains the range for each op's input and output,
// this whole process can be divided into several steps:
//
// 1. Builds a 32-bit engine, runs it on the calibration set, and records a
// histogram for each
// tensor of the distribution of activation values.
// 2. Builds a calibration table from the histograms.
//
// After step 2, we need to store the calibration table on disk
bool
SaveTrtCalibToDisk
();
#endif
// Some more detailed tests, they are made the friends of the predictor, so that
// the all the details can be tested.
#if PADDLE_WITH_TESTING
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
97b76c94
...
...
@@ -42,6 +42,10 @@ struct AnalysisConfig {
explicit
AnalysisConfig
(
const
std
::
string
&
model_dir
);
explicit
AnalysisConfig
(
const
std
::
string
&
prog_file
,
const
std
::
string
&
params_file
);
enum
class
Precision
{
kFloat32
=
0
,
kInt8
,
};
/** Set model with a directory.
*/
...
...
@@ -135,7 +139,8 @@ struct AnalysisConfig {
* subgraph is less than this, it will not transfer to TensorRT engine.
*/
void
EnableTensorRtEngine
(
int
workspace_size
=
1
<<
20
,
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
);
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
,
Precision
precision
=
Precision
::
kFloat32
);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool
tensorrt_engine_enabled
()
const
{
return
use_tensorrt_
;
}
...
...
@@ -229,6 +234,7 @@ struct AnalysisConfig {
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
int
tensorrt_min_subgraph_size_
{
3
};
Precision
tensorrt_precision_mode_
;
// memory reuse related.
bool
enable_memory_optim_
{
false
};
...
...
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
97b76c94
nv_library
(
tensorrt_engine SRCS engine.cc DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_engine SRCS engine.cc
trt_int8_calibrator.cc
DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
97b76c94
...
...
@@ -69,6 +69,13 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_
->
setMaxBatchSize
(
max_batch_
);
infer_builder_
->
setMaxWorkspaceSize
(
max_workspace_
);
if
(
enable_int8_
)
{
infer_builder_
->
setInt8Mode
(
true
);
PADDLE_ENFORCE
(
calibrator_
!=
nullptr
,
"The precision mode is 'INT8', the calibrator should not be nullptr"
);
infer_builder_
->
setInt8Calibrator
(
calibrator_
);
}
infer_engine_
.
reset
(
infer_builder_
->
buildCudaEngine
(
*
infer_network_
));
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"build cuda engine failed!"
);
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
97b76c94
...
...
@@ -23,12 +23,14 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TRTInt8Calibrator
;
/*
* TensorRT Engine.
*
...
...
@@ -55,13 +57,16 @@ class TensorRTEngine : public EngineBase {
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
stream
,
int
device
=
0
,
int
device
=
0
,
bool
enable_int8
=
false
,
TRTInt8Calibrator
*
calibrator
=
nullptr
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
),
logger_
(
logger
),
device_
(
device
)
{}
device_
(
device
),
enable_int8_
(
enable_int8
),
calibrator_
(
calibrator
),
logger_
(
logger
)
{}
virtual
~
TensorRTEngine
();
...
...
@@ -139,8 +144,8 @@ class TensorRTEngine : public EngineBase {
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into
//
one conv, and then trigger bug. So, We should use strategy to avoid
this
// into
one conv, and then trigger bug. So, We should use strategy to avoid
// this
// optimization for the time being. This bug will be fixed in the future.
std
::
unordered_map
<
std
::
string
/*name*/
,
int
/*ITensor_quote_num*/
>
itensor_quote_num
;
...
...
@@ -153,9 +158,14 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int
max_workspace_
;
cudaStream_t
stream_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
bool
enable_int8_
;
TRTInt8Calibrator
*
calibrator_
;
// batch size of the current data, will be updated each Executation.
int
batch_size_
{
-
1
};
cudaStream_t
stream_
;
nvinfer1
::
ILogger
&
logger_
;
...
...
@@ -165,8 +175,6 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
std
::
vector
<
std
::
unique_ptr
<
plugin
::
PluginTensorRT
>>
owned_plugin_
;
// TensorRT related internal members
...
...
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
0 → 100644
浏览文件 @
97b76c94
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
// set the batch size before constructing the thread to execute engine
int
TRTInt8Calibrator
::
getBatchSize
()
const
{
return
batch_size_
;
}
TRTInt8Calibrator
::
TRTInt8Calibrator
(
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
buffers
,
int
batch_size
,
std
::
string
engine_name
,
const
platform
::
Place
place
)
:
batch_size_
(
batch_size
),
engine_name_
(
engine_name
)
{
int
i
=
0
;
VLOG
(
4
)
<<
"Init a new calibrator: "
<<
engine_name_
;
for
(
const
auto
it
:
buffers
)
{
framework
::
Tensor
temp_tensor
;
std
::
string
input_name
=
it
.
first
;
int
data_size
=
it
.
second
;
int
num_ele
=
data_size
/
sizeof
(
int16_t
);
framework
::
DDim
data_shape
=
framework
::
make_ddim
({
num_ele
});
temp_tensor
.
Resize
(
data_shape
);
data_tensors_
.
push_back
(
temp_tensor
);
data_buffers_
[
input_name
]
=
std
::
pair
<
void
*
,
size_t
>
(
static_cast
<
void
*>
(
temp_tensor
.
mutable_data
<
int16_t
>
(
place
)),
num_ele
);
i
+=
1
;
}
}
TRTInt8Calibrator
::
TRTInt8Calibrator
(
const
std
::
string
&
calib_data
)
:
batch_size_
(
0
),
calib_running_
(
false
),
data_is_set_
(
false
),
done_
(
true
),
calibration_table_
(
calib_data
)
{}
void
TRTInt8Calibrator
::
waitAndSetDone
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
while
((
calib_running_
||
data_is_set_
)
&&
!
done_
)
cond_
.
wait
(
lk
);
if
(
!
done_
)
{
done_
=
true
;
cond_
.
notify_all
();
}
}
// There might be more than one input for trt subgraph,
// So, we use a map to store input information.
bool
TRTInt8Calibrator
::
setBatch
(
const
std
::
unordered_map
<
std
::
string
,
void
*>&
data
)
{
VLOG
(
3
)
<<
"set batch: "
<<
engine_name_
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
// There is a producer and a consumer. The producer set the batch data and
// the consumer get the batch data. The size of the data pool is one.
// So, the producer has to wait for the consumer to finish processing before
// they can set the data.
while
((
calib_running_
||
data_is_set_
)
&&
(
!
done_
))
cond_
.
wait
(
lk
);
// The done_ is set to true using waitAndSetDone, When all calibration data
// are processed.
if
(
done_
)
return
false
;
// Sets the batch.
for
(
const
auto
&
it
:
data
)
{
auto
dataptr
=
data_buffers_
.
find
(
it
.
first
);
if
(
dataptr
==
data_buffers_
.
end
())
{
LOG
(
FATAL
)
<<
"FATAL "
<<
engine_name_
<<
" input name '"
<<
it
.
first
<<
"' does not match with the buffer names"
;
}
const
auto
&
d
=
dataptr
->
second
;
PADDLE_ENFORCE
(
cudaMemcpy
(
d
.
first
,
it
.
second
,
d
.
second
,
cudaMemcpyDeviceToDevice
),
"Fail to cudaMemcpy %s for %s"
,
engine_name_
,
it
.
first
);
}
data_is_set_
=
true
;
cond_
.
notify_all
();
return
true
;
}
bool
TRTInt8Calibrator
::
getBatch
(
void
**
bindings
,
const
char
**
names
,
int
num_bindings
)
{
VLOG
(
4
)
<<
"get batch: "
<<
engine_name_
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
// The consumer has just finished processing a data.
// The producer can set the data again.
calib_running_
=
false
;
cond_
.
notify_all
();
// As long as there is data in the pool, the consumer can get it.
while
(
!
data_is_set_
&&
!
done_
)
cond_
.
wait
(
lk
);
if
(
done_
)
return
false
;
// Gets the batch
for
(
int
i
=
0
;
i
<
num_bindings
;
i
++
)
{
auto
it
=
data_buffers_
.
find
(
names
[
i
]);
if
(
it
==
data_buffers_
.
end
())
{
LOG
(
FATAL
)
<<
"Calibration engine asked for unknown tensor name '"
<<
names
[
i
]
<<
"' at position "
<<
i
;
}
bindings
[
i
]
=
it
->
second
.
first
;
}
data_is_set_
=
false
;
calib_running_
=
true
;
VLOG
(
4
)
<<
"get batch done: "
<<
engine_name_
;
return
true
;
}
void
TRTInt8Calibrator
::
setDone
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
done_
=
true
;
cond_
.
notify_all
();
}
const
void
*
TRTInt8Calibrator
::
readCalibrationCache
(
size_t
&
length
)
{
if
(
calibration_table_
.
empty
())
return
nullptr
;
length
=
calibration_table_
.
size
();
return
calibration_table_
.
data
();
}
void
TRTInt8Calibrator
::
writeCalibrationCache
(
const
void
*
ptr
,
std
::
size_t
length
)
{
calibration_table_
=
std
::
string
((
const
char
*
)
ptr
,
length
);
VLOG
(
4
)
<<
"Got calibration data for "
<<
engine_name_
<<
" "
<<
ptr
<<
" length="
<<
length
;
}
TRTInt8Calibrator
::~
TRTInt8Calibrator
()
{
VLOG
(
4
)
<<
"Destroying calibrator for "
<<
engine_name_
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
0 → 100644
浏览文件 @
97b76c94
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TensorRTEngine
;
struct
TRTInt8Calibrator
:
public
nvinfer1
::
IInt8EntropyCalibrator
{
public:
TRTInt8Calibrator
(
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
buffers
,
int
batch_size
,
std
::
string
engine_name
,
const
platform
::
Place
place
);
explicit
TRTInt8Calibrator
(
const
std
::
string
&
calibration_data
);
~
TRTInt8Calibrator
();
int
getBatchSize
()
const
override
;
bool
getBatch
(
void
*
bindings
[],
const
char
*
names
[],
int
num_bindings
)
override
;
bool
setBatch
(
const
std
::
unordered_map
<
std
::
string
,
void
*>&
data
);
void
setDone
();
void
waitAndSetDone
();
const
void
*
readCalibrationCache
(
std
::
size_t
&
length
)
override
;
void
writeCalibrationCache
(
const
void
*
ptr
,
std
::
size_t
length
)
override
;
const
std
::
string
&
getCalibrationTableAsString
()
{
return
calibration_table_
;
}
private:
const
int
batch_size_
;
bool
calib_running_
{
true
};
bool
data_is_set_
{
false
};
bool
done_
{
false
};
std
::
mutex
mut_
;
std
::
condition_variable
cond_
;
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
void
*
,
size_t
>>
data_buffers_
;
std
::
vector
<
framework
::
Tensor
>
data_tensors_
;
std
::
string
engine_name_
;
std
::
string
calibration_table_
;
};
class
TRTCalibratorEngine
{
public:
TRTCalibratorEngine
()
{}
std
::
unique_ptr
<
TRTInt8Calibrator
>
calib_
;
std
::
unique_ptr
<
std
::
thread
>
thr_
;
std
::
unique_ptr
<
TensorRTEngine
>
engine_
;
};
/*
* Manager to control the TensorRT Int8 calibration creation and deltetion.
*/
class
TRTCalibratorEngineManager
{
public:
bool
Has
()
const
{
return
res_
.
size
()
>
0
;
}
bool
Has
(
const
std
::
string
&
name
)
const
{
if
(
res_
.
count
(
name
)
==
0
)
return
false
;
return
res_
.
at
(
name
).
get
()
!=
nullptr
;
}
// Get Int8Calibrator via name
TRTCalibratorEngine
*
Get
(
const
std
::
string
&
name
)
const
{
return
res_
.
at
(
name
).
get
();
}
// Look up or create a calibrator.
TRTCalibratorEngine
*
LookupOrCreate
(
const
std
::
string
&
engine_name
)
{
if
(
res_
.
count
(
engine_name
)
==
0
)
{
auto
*
p
=
new
TRTCalibratorEngine
;
res_
[
engine_name
].
reset
(
p
);
}
return
res_
.
at
(
engine_name
).
get
();
}
// Create an Int8Calibrator
TRTCalibratorEngine
*
Create
(
const
std
::
string
&
engine_name
)
{
auto
*
p
=
new
TRTCalibratorEngine
;
res_
[
engine_name
].
reset
(
p
);
return
p
;
}
void
DeleteALL
()
{
for
(
auto
&
item
:
res_
)
{
item
.
second
.
reset
(
nullptr
);
}
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
TRTCalibratorEngine
>>
res_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
浏览文件 @
97b76c94
...
...
@@ -29,8 +29,14 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Xs"
,
"A list of inputs."
).
AsDuplicable
();
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph."
);
AddAttr
<
std
::
string
>
(
"calibration_data"
,
"the calibration data for int8"
);
AddAttr
<
std
::
string
>
(
"engine_key"
,
"The engine_key here is used to distinguish different TRT Engines"
);
AddAttr
<
int
>
(
"max_batch_size"
,
"the maximum batch size."
);
AddAttr
<
int
>
(
"workspace_size"
,
"the workspace size."
);
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
"the trt block"
);
AddAttr
<
bool
>
(
"enable_int8"
,
"whether swith to int8 mode"
);
AddComment
(
"TensorRT engine operator."
);
}
};
...
...
@@ -47,6 +53,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
tensorrt_engine
,
ops
::
TensorRTEngineOp
,
ops
::
TensorRTEngineOpMaker
);
ops
::
TensorRTEngineOpMaker
,
ops
::
TensorRTEngineOpMaker
);
#endif // PADDLE_WITH_CUDA
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
97b76c94
...
...
@@ -17,8 +17,10 @@
#ifdef PADDLE_WITH_CUDA
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
...
...
@@ -62,6 +64,9 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TensorRTEngine
;
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
using
inference
::
tensorrt
::
TRTCalibratorEngine
;
using
inference
::
tensorrt
::
TRTCalibratorEngineManager
;
class
TensorRTEngineOp
:
public
framework
::
OperatorBase
{
private:
...
...
@@ -70,6 +75,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable
std
::
unique_ptr
<
TensorRTEngine
>
trt_engine_
;
int
max_batch_size_
;
int
workspace_size_
;
std
::
unique_ptr
<
TRTInt8Calibrator
>
calibrator_
;
bool
enable_int8_
;
std
::
string
calibration_data_
;
std
::
string
engine_key_
;
bool
calibration_mode_
;
public:
TensorRTEngineOp
(
const
std
::
string
&
type
,
...
...
@@ -80,19 +90,96 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_
=
Inputs
(
"Xs"
);
max_batch_size_
=
Attr
<
int
>
(
"max_batch_size"
);
workspace_size_
=
Attr
<
int
>
(
"workspace_size"
);
enable_int8_
=
Attr
<
bool
>
(
"enable_int8"
);
calibration_data_
=
Attr
<
std
::
string
>
(
"calibration_data"
);
engine_key_
=
Attr
<
std
::
string
>
(
"engine_key"
);
auto
params
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
for
(
const
auto
&
param
:
params
)
{
param_names_
.
insert
(
param
);
}
// calibration_mode is ture represents we need to
// generate the calibration table data.
calibration_mode_
=
(
enable_int8_
&&
calibration_data_
.
size
()
==
0
);
VLOG
(
4
)
<<
"calibration_mode: "
<<
calibration_mode_
;
if
(
enable_int8_
&&
calibration_data_
.
size
())
{
calibrator_
.
reset
(
new
TRTInt8Calibrator
(
calibration_data_
));
}
}
protected:
void
RunNativeImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
framework
::
Executor
executor
(
dev_place
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
"sub_block"
);
auto
*
program
=
block
->
Program
();
auto
&
current_scope
=
scope
.
NewScope
();
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
executor
.
RunPreparedContext
(
ctx
.
get
(),
&
current_scope
,
false
,
true
,
true
);
}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
if
(
calibration_mode_
==
true
)
{
RunCalibration
(
scope
,
dev_place
);
return
;
}
RunTrt
(
scope
,
dev_place
);
}
void
RunCalibration
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
// This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each
// tensor of the distribution of activation values.
LOG_FIRST_N
(
INFO
,
1
)
<<
"The TRT engine: "
<<
engine_key_
<<
" is running calibration trt int8... "
;
int
runtime_batch
=
1
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
!
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Has
(
engine_key_
))
{
TRTCalibratorEngine
*
calib_res
=
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Create
(
engine_key_
);
std
::
unordered_map
<
std
::
string
,
size_t
>
calib_buffers
;
for
(
auto
&
x
:
input_names_
)
{
if
(
param_names_
.
count
(
x
))
continue
;
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
calib_buffers
[
x
]
=
t
.
memory_size
();
auto
t_shape
=
framework
::
vectorize
(
t
.
dims
());
runtime_batch
=
t_shape
[
0
];
}
calib_res
->
calib_
.
reset
(
new
TRTInt8Calibrator
(
calib_buffers
,
runtime_batch
,
engine_key_
,
dev_place
));
calib_res
->
thr_
.
reset
(
new
std
::
thread
([
&
]()
{
calib_res
->
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
enable_int8_
,
calib_res
->
calib_
.
get
()));
VLOG
(
3
)
<<
"start the calib trt engine thread"
;
Prepare
(
scope
,
dev_place
,
calib_res
->
engine_
.
get
());
}));
}
TRTInt8Calibrator
*
temp_calibrator
=
Singleton
<
TRTCalibratorEngineManager
>::
Global
()
.
Get
(
engine_key_
)
->
calib_
.
get
();
std
::
unordered_map
<
std
::
string
,
void
*>
calib_data
;
for
(
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
calib_data
.
emplace
(
x
,
t
.
data
<
void
>
());
}
temp_calibrator
->
setBatch
(
calib_data
);
RunNativeImpl
(
scope
,
dev_place
);
}
void
RunTrt
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
int
runtime_batch
=
1
;
...
...
@@ -101,9 +188,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
));
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
enable_int8_
,
calibrator_
.
get
()));
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
}
...
...
@@ -173,7 +261,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
void
Prepare
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
TensorRTEngine
*
engine
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
LOG
(
INFO
)
<<
"Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time."
;
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
Attr
<
std
::
string
>
(
"subgraph"
));
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
浏览文件 @
97b76c94
...
...
@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc
.
SetType
(
"tensorrt_engine"
);
engine_op_desc
.
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
({
"x"
}));
engine_op_desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch_size"
,
2
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
1
<<
20
);
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"engine_uniq_key"
,
"a_engine"
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
std
::
vector
<
std
::
string
>
({}));
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"output_name_mapping"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
engine_op_desc
.
SetBlockAttr
(
"sub_block"
,
&
block_desc
);
engine_op_desc
.
SetAttr
(
"max_batch_size"
,
static_cast
<
int
>
(
2
));
engine_op_desc
.
SetAttr
(
"workspace_size"
,
static_cast
<
int
>
(
1
<<
20
));
engine_op_desc
.
SetAttr
(
"parameters"
,
std
::
vector
<
std
::
string
>
({}));
engine_op_desc
.
SetAttr
(
"engine_key"
,
std
::
string
(
"a_engine"
));
engine_op_desc
.
SetAttr
(
"calibration_data"
,
std
::
string
(
""
));
engine_op_desc
.
SetAttr
(
"enable_int8"
,
static_cast
<
bool
>
(
false
));
engine_op_desc
.
SetAttr
(
"output_name_mapping"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
LOG
(
INFO
)
<<
"create engine op"
;
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
*
engine_op_desc
.
Proto
()
);
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
LOG
(
INFO
)
<<
"engine_op "
<<
engine_op
.
get
();
framework
::
Scope
scope
;
...
...
@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc
.
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
({
"x0"
}));
engine_op_desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
({
"z3"
}));
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch_size"
,
batch_size
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
1
<<
20
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
std
::
vector
<
std
::
string
>
({
"y0"
,
"y1"
,
"y2"
,
"y3"
}));
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"engine_uniq_key"
,
"b_engine"
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"output_name_mapping"
,
std
::
vector
<
std
::
string
>
({
"z3"
}));
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
*
engine_op_desc
.
Proto
());
engine_op_desc
.
SetBlockAttr
(
"sub_block"
,
&
block_desc
);
engine_op_desc
.
SetAttr
(
"max_batch_size"
,
static_cast
<
int
>
(
batch_size
));
engine_op_desc
.
SetAttr
(
"workspace_size"
,
static_cast
<
int
>
(
1
<<
20
));
engine_op_desc
.
SetAttr
(
"parameters"
,
std
::
vector
<
std
::
string
>
({
"y0"
,
"y1"
,
"y2"
,
"y3"
}));
engine_op_desc
.
SetAttr
(
"engine_key"
,
std
::
string
(
"b_engine"
));
engine_op_desc
.
SetAttr
(
"calibration_data"
,
std
::
string
(
""
));
engine_op_desc
.
SetAttr
(
"enable_int8"
,
static_cast
<
bool
>
(
false
));
engine_op_desc
.
SetAttr
(
"output_name_mapping"
,
std
::
vector
<
std
::
string
>
({
"z3"
}));
engine_op_desc
.
SetAttr
(
"subgraph"
,
std
::
string
(
block_
->
SerializeAsString
()));
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
engine_op_desc
);
// Execute them.
engine_op
->
Run
(
scope
,
place
);
...
...
paddle/fluid/pybind/inference_api.cc
浏览文件 @
97b76c94
...
...
@@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) {
}
void
BindAnalysisConfig
(
py
::
module
*
m
)
{
py
::
class_
<
AnalysisConfig
>
(
*
m
,
"AnalysisConfig"
)
.
def
(
py
::
init
<
const
AnalysisConfig
&>
())
py
::
class_
<
AnalysisConfig
>
analysis_config
(
*
m
,
"AnalysisConfig"
);
py
::
enum_
<
AnalysisConfig
::
Precision
>
(
analysis_config
,
"Precision"
)
.
value
(
"Float32"
,
AnalysisConfig
::
Precision
::
kFloat32
)
.
value
(
"Int8"
,
AnalysisConfig
::
Precision
::
kInt8
)
.
export_values
();
analysis_config
.
def
(
py
::
init
<
const
AnalysisConfig
&>
())
.
def
(
py
::
init
<
const
std
::
string
&>
())
.
def
(
py
::
init
<
const
std
::
string
&
,
const
std
::
string
&>
())
.
def
(
"set_model"
,
(
void
(
AnalysisConfig
::*
)(
const
std
::
string
&
))
&
...
...
@@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) {
.
def
(
"specify_input_name"
,
&
AnalysisConfig
::
specify_input_name
)
.
def
(
"enable_tensorrt_engine"
,
&
AnalysisConfig
::
EnableTensorRtEngine
,
py
::
arg
(
"workspace_size"
)
=
1
<<
20
,
py
::
arg
(
"max_batch_size"
)
=
1
,
py
::
arg
(
"min_subgraph_size"
)
=
3
)
py
::
arg
(
"min_subgraph_size"
)
=
3
,
py
::
arg
(
"precision_mode"
)
=
AnalysisConfig
::
Precision
::
kFloat32
)
.
def
(
"tensorrt_engine_enabled"
,
&
AnalysisConfig
::
tensorrt_engine_enabled
)
.
def
(
"switch_ir_debug"
,
&
AnalysisConfig
::
SwitchIrDebug
,
py
::
arg
(
"x"
)
=
true
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录