Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4e3522e5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e3522e5
编写于
1月 09, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add trt int8 support
test=develop
上级
d09d6ead
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
514 addition
and
50 deletion
+514
-50
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+3
-0
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+15
-0
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+10
-1
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+28
-1
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+6
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+57
-0
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+3
-0
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+3
-1
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+7
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+17
-40
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
+144
-0
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
+128
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
+8
-1
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+84
-4
未找到文件。
paddle/fluid/inference/analysis/argument.h
浏览文件 @
4e3522e5
...
...
@@ -104,6 +104,7 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
model_program_path
,
ModelProgramPath
,
std
::
string
);
DECL_ARGUMENT_FIELD
(
model_params_path
,
ModelParamsPath
,
std
::
string
);
DECL_ARGUMENT_FIELD
(
model_from_memory
,
ModelFromMemory
,
bool
);
DECL_ARGUMENT_FIELD
(
model_path
,
ModelPath
,
std
::
string
);
// The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD
(
main_graph
,
MainGraph
,
framework
::
ir
::
Graph
);
...
...
@@ -126,6 +127,8 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
tensorrt_max_batch_size
,
TensorRtMaxBatchSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_workspace_size
,
TensorRtWorkspaceSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_min_subgraph_size
,
TensorRtMinSubgraphSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_precision_mode
,
TensorRtPrecisionMode
,
std
::
string
);
// The program transformed by IR analysis phase.
DECL_ARGUMENT_UNIQUE_FIELD
(
ir_analyzed_program
,
IrAnalyzedProgram
,
...
...
paddle/fluid/inference/analysis/helper.h
浏览文件 @
4e3522e5
...
...
@@ -156,6 +156,21 @@ static bool PathExists(const std::string &path) {
return
false
;
}
static
std
::
string
SplitPath
(
const
std
::
string
path
)
{
char
sep
=
'/'
;
#ifdef _WIN32
sep
=
'\\'
;
#endif
size_t
i
=
path
.
rfind
(
sep
,
path
.
length
());
if
(
i
!=
std
::
string
::
npos
)
{
return
(
path
.
substr
(
0
,
i
));
}
return
path
;
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
4e3522e5
...
...
@@ -67,9 +67,17 @@ void IRPassManager::CreatePasses(Argument *argument,
pass
->
Set
(
"max_batch_size"
,
new
int
(
argument
->
tensorrt_max_batch_size
()));
pass
->
Set
(
"min_subgraph_size"
,
new
int
(
argument
->
tensorrt_min_subgraph_size
()));
pass
->
Set
(
"program"
,
new
framework
::
ProgramDesc
*
(
const_cast
<
framework
::
ProgramDesc
*>
(
&
argument
->
main_program
())));
pass
->
Set
(
"precision_mode"
,
new
std
::
string
(
argument
->
tensorrt_precision_mode
()));
pass
->
Set
(
"model_dir"
,
new
std
::
string
(
argument
->
model_path
()));
}
// graph_ = pass->Apply(std::move(graph_));
pre_pass
=
pass_name
;
passes_
.
emplace_back
(
std
::
move
(
pass
));
...
...
@@ -94,7 +102,8 @@ framework::proto::ProgramDesc IRPassManager::AcquireProgram(
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_to_program_pass"
);
ProgramDesc
desc
(
program
);
ProgramDesc
desc
;
desc
.
CopyFrom
(
*
const_cast
<
ProgramDesc
&>
(
program
).
Proto
());
pass
->
SetNotOwned
(
"program"
,
&
desc
);
auto
*
the_graph
=
graph
->
release
();
*
graph
=
pass
->
Apply
(
std
::
unique_ptr
<
Graph
>
(
the_graph
));
...
...
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
4e3522e5
...
...
@@ -72,13 +72,23 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
PADDLE_ENFORCE
(
!
subgraph
.
empty
());
framework
::
ProgramDesc
*
program_desc
=
Get
<
framework
::
ProgramDesc
*>
(
"program"
);
// Add new block for TensorRTEngineOP
const
framework
::
BlockDesc
&
main_block
=
program_desc
->
Block
(
framework
::
kRootBlockIndex
);
// const framework::BlockDesc& main_block = program_desc->Block(0);
framework
::
BlockDesc
*
new_block
=
program_desc
->
AppendBlock
(
main_block
);
// An fake block desc.
framework
::
proto
::
BlockDesc
block_proto
;
framework
::
BlockDesc
block_desc
(
nullptr
,
&
block_proto
);
block_desc
.
Proto
()
->
set_parent_idx
(
-
1
);
block_desc
.
Proto
()
->
set_idx
(
0
);
for
(
auto
*
node
:
subgraph
)
{
auto
*
new_block_op
=
new_block
->
AppendOp
();
auto
*
op
=
block_desc
.
AppendOp
();
*
new_block_op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
*
op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
}
...
...
@@ -178,7 +188,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// to Tensor.
std
::
vector
<
std
::
string
>
output_mapping
;
for
(
auto
name
:
output_names
)
{
// LOG(INFO) << name << " " << output_name_map.size();
PADDLE_ENFORCE
(
output_name_map
.
count
(
name
)
!=
0
);
output_mapping
.
push_back
(
output_name_map
[
name
]);
}
...
...
@@ -189,9 +198,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*
vars
->
Add
()
=
*
node
->
Var
()
->
Proto
();
}
}
PADDLE_ENFORCE
(
!
block_desc
.
Proto
()
->
vars
().
empty
(),
"the block has no var-desc"
);
PADDLE_ENFORCE
(
!
output_mapping
.
empty
());
op_desc
->
SetBlockAttr
(
"sub_block"
,
new_block
);
// Set attrs
SetAttr
(
op_desc
->
Proto
(),
"subgraph"
,
block_desc
.
Proto
()
->
SerializeAsString
());
...
...
@@ -199,6 +210,22 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
SetAttr
(
op_desc
->
Proto
(),
"workspace_size"
,
Get
<
int
>
(
"workspace_size"
));
SetAttr
(
op_desc
->
Proto
(),
"parameters"
,
ExtractParameters
(
graph
->
Nodes
()));
SetAttr
(
op_desc
->
Proto
(),
"output_name_mapping"
,
output_mapping
);
std
::
string
engine_key
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
block_desc
.
Proto
()
->
SerializeAsString
()));
std
::
string
precision_mode
=
Get
<
std
::
string
>
(
"precision_mode"
);
SetAttr
(
op_desc
->
Proto
(),
"calibration_data"
,
std
::
string
(
""
));
std
::
string
trt_calib_file
=
Get
<
std
::
string
>
(
"model_dir"
)
+
"/trt_calib_"
+
engine_key
;
if
(
precision_mode
==
"INT8"
&&
FileExists
(
trt_calib_file
))
{
std
::
ifstream
infile
(
trt_calib_file
,
std
::
ios
::
in
);
std
::
stringstream
buffer
;
buffer
<<
infile
.
rdbuf
();
std
::
string
calibration_data
(
buffer
.
str
());
SetAttr
(
op_desc
->
Proto
(),
"calibration_data"
,
calibration_data
);
}
SetAttr
(
op_desc
->
Proto
(),
"precision_mode"
,
precision_mode
);
SetAttr
(
op_desc
->
Proto
(),
"engine_key"
,
engine_key
);
}
std
::
vector
<
std
::
string
>
ExtractParameters
(
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
4e3522e5
...
...
@@ -86,6 +86,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER
(
tensorrt_workspace_size_
);
CP_MEMBER
(
tensorrt_max_batchsize_
);
CP_MEMBER
(
tensorrt_min_subgraph_size_
);
CP_MEMBER
(
tensorrt_precision_mode_
);
// MKLDNN releated.
CP_MEMBER
(
use_mkldnn_
);
CP_MEMBER
(
mkldnn_enabled_op_types_
);
...
...
@@ -123,10 +124,13 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
void
contrib
::
AnalysisConfig
::
EnableTensorRtEngine
(
int
workspace_size
,
int
max_batch_size
,
int
min_subgraph_size
)
{
int
min_subgraph_size
,
std
::
string
precision_mode
)
{
use_tensorrt_
=
true
;
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
tensorrt_precision_mode_
=
precision_mode
;
Update
();
}
void
contrib
::
AnalysisConfig
::
Update
()
{
...
...
@@ -176,6 +180,7 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
ss
<<
use_tensorrt_
;
ss
<<
tensorrt_workspace_size_
;
ss
<<
tensorrt_max_batchsize_
;
ss
<<
tensorrt_precision_mode_
;
ss
<<
use_mkldnn_
;
ss
<<
enable_ir_optim_
;
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
4e3522e5
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
...
...
@@ -30,6 +31,8 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#endif
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
...
...
@@ -41,6 +44,10 @@ DECLARE_bool(profile);
namespace
paddle
{
using
contrib
::
AnalysisConfig
;
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
using
inference
::
tensorrt
::
TRTCalibratorRes
;
using
inference
::
tensorrt
::
TRTCalibratorResManager
;
namespace
{
bool
IsPersistable
(
const
framework
::
VarDesc
*
var
)
{
...
...
@@ -321,11 +328,15 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Analyze inference_program
if
(
!
config_
.
model_dir
().
empty
())
{
argument_
.
SetModelDir
(
config_
.
model_dir
());
argument_
.
SetModelPath
(
config_
.
model_dir
());
}
else
{
PADDLE_ENFORCE
(
!
config_
.
params_file
().
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
().
empty
());
std
::
string
dir
=
inference
::
analysis
::
SplitPath
(
config_
.
prog_file
());
argument_
.
SetModelPath
(
dir
);
argument_
.
SetModelProgramPath
(
config_
.
prog_file
());
argument_
.
SetModelParamsPath
(
config_
.
params_file
());
}
...
...
@@ -335,6 +346,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_
.
SetTensorRtWorkspaceSize
(
config_
.
tensorrt_workspace_size_
);
argument_
.
SetTensorRtMaxBatchSize
(
config_
.
tensorrt_max_batchsize_
);
argument_
.
SetTensorRtMinSubgraphSize
(
config_
.
tensorrt_min_subgraph_size_
);
argument_
.
SetTensorRtPrecisionMode
(
config_
.
tensorrt_precision_mode_
);
}
if
(
config_
.
use_mkldnn_
)
{
...
...
@@ -550,7 +562,52 @@ bool AnalysisPredictor::LoadParameters() {
return
true
;
}
bool
AnalysisPredictor
::
SaveTrtCalibToDisk
()
{
PADDLE_ENFORCE
(
config_
.
tensorrt_engine_enabled
(),
"This func can be invoked only in trt mode"
);
auto
&
block
=
inference_program_
->
Block
(
0
);
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
if
(
op_desc
->
Type
()
==
"tensorrt_engine"
)
{
std
::
string
engine_name
=
boost
::
get
<
std
::
string
>
(
op_desc
->
GetAttr
(
"engine_key"
));
if
(
!
Singleton
<
TRTCalibratorResManager
>::
Global
().
Has
(
engine_name
))
{
LOG
(
ERROR
)
<<
"You should run the predictor(with trt) on the real data "
"to generate calibration info"
;
return
false
;
}
TRTCalibratorRes
*
calib_res
=
Singleton
<
TRTCalibratorResManager
>::
Global
().
Get
(
engine_name
);
LOG
(
INFO
)
<<
"Wait for calib threads done."
;
calib_res
->
calib_
->
waitAndSetDone
();
LOG
(
INFO
)
<<
"Finish wait."
;
calib_res
->
thr_
->
join
();
std
::
string
calibration_data
=
calib_res
->
calib_
->
getCalibrationTableAsString
();
if
(
calibration_data
.
size
()
==
0
)
{
LOG
(
ERROR
)
<<
"the calibration table is empty."
;
return
false
;
}
std
::
string
calibration_data_path
=
argument_
.
model_path
()
+
"/trt_calib_"
+
engine_name
;
std
::
ofstream
ofile
(
calibration_data_path
,
std
::
ios
::
out
);
LOG
(
INFO
)
<<
"Write Paddle-TRT INT8 calibration data to file "
<<
calibration_data_path
;
ofile
<<
calibration_data
;
ofile
.
close
();
}
}
// Free all calibrator resources.
Singleton
<
TRTCalibratorResManager
>::
Global
().
DeleteALL
();
return
true
;
}
AnalysisPredictor
::~
AnalysisPredictor
()
{
if
(
config_
.
tensorrt_engine_enabled
()
&&
config_
.
tensorrt_precision_mode_
==
"INT8"
&&
Singleton
<
TRTCalibratorResManager
>::
Global
().
Has
())
{
SaveTrtCalibToDisk
();
}
if
(
FLAGS_profile
)
{
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kTotal
,
"./profile.log"
);
...
...
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
4e3522e5
...
...
@@ -90,6 +90,9 @@ class AnalysisPredictor : public PaddlePredictor {
template
<
typename
T
>
void
GetFetchOne
(
const
framework
::
LoDTensor
&
fetchs
,
PaddleTensor
*
output_data
);
bool
SaveTrtCalibToDisk
();
~
AnalysisPredictor
();
// Some more detailed tests, they are made the friends of the predictor, so that
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
4e3522e5
...
...
@@ -135,7 +135,8 @@ struct AnalysisConfig {
* subgraph is less than this, it will not transfer to TensorRT engine.
*/
void
EnableTensorRtEngine
(
int
workspace_size
=
1
<<
20
,
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
);
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
,
std
::
string
precision
=
"FP32"
);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool
tensorrt_engine_enabled
()
const
{
return
use_tensorrt_
;
}
...
...
@@ -231,6 +232,7 @@ struct AnalysisConfig {
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
int
tensorrt_min_subgraph_size_
{
3
};
std
::
string
tensorrt_precision_mode_
;
bool
use_mkldnn_
{
false
};
std
::
unordered_set
<
std
::
string
>
mkldnn_enabled_op_types_
;
...
...
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
4e3522e5
nv_library
(
tensorrt_engine SRCS engine.cc DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_engine SRCS engine.cc
trt_int8_calibrator.cc
DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
4e3522e5
...
...
@@ -70,6 +70,13 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_
->
setMaxBatchSize
(
max_batch_
);
infer_builder_
->
setMaxWorkspaceSize
(
max_workspace_
);
if
(
precision_mode_
==
"INT8"
)
{
infer_builder_
->
setInt8Mode
(
true
);
PADDLE_ENFORCE
(
calibrator_
!=
nullptr
,
"The precision mode is 'INT8', the calibrator should not be nullptr"
);
infer_builder_
->
setInt8Calibrator
(
calibrator_
);
}
infer_engine_
.
reset
(
infer_builder_
->
buildCudaEngine
(
*
infer_network_
));
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"build cuda engine failed!"
);
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
4e3522e5
...
...
@@ -23,12 +23,14 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TRTInt8Calibrator
;
/*
* TensorRT Engine.
*
...
...
@@ -56,12 +58,16 @@ class TensorRTEngine : public EngineBase {
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
=
nullptr
,
int
device
=
0
,
std
::
string
precision_mode
=
"FP32"
,
TRTInt8Calibrator
*
calibrator
=
nullptr
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
?
stream
:
&
default_stream_
),
logger_
(
logger
),
device_
(
device
)
{
device_
(
device
),
precision_mode_
(
precision_mode
),
calibrator_
(
calibrator
),
logger_
(
logger
)
{
freshDeviceId
();
cudaStreamCreate
(
stream_
);
}
...
...
@@ -142,8 +148,8 @@ class TensorRTEngine : public EngineBase {
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into
//
one conv, and then trigger bug. So, We should use strategy to avoid
this
// into
one conv, and then trigger bug. So, We should use strategy to avoid
// this
// optimization for the time being. This bug will be fixed in the future.
std
::
unordered_map
<
std
::
string
/*name*/
,
int
/*ITensor_quote_num*/
>
itensor_quote_num
;
...
...
@@ -156,11 +162,16 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int
max_workspace_
;
// batch size of the current data, will be updated each Executation.
int
batch_size_
{
-
1
};
cudaStream_t
*
stream_
;
// If stream_ is not set from outside, hold its own stream.
cudaStream_t
default_stream_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
std
::
string
precision_mode_
;
TRTInt8Calibrator
*
calibrator_
;
// batch size of the current data, will be updated each Executation.
int
batch_size_
{
-
1
};
nvinfer1
::
ILogger
&
logger_
;
std
::
vector
<
Buffer
>
buffers_
;
...
...
@@ -169,8 +180,6 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
std
::
vector
<
std
::
unique_ptr
<
plugin
::
PluginTensorRT
>>
owned_plugin_
;
// TensorRT related internal members
...
...
@@ -208,38 +217,6 @@ class TensorRTEngine : public EngineBase {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class
TRT_EngineManager
{
public:
bool
HasEngine
(
const
std
::
string
&
name
)
const
{
return
engines_
.
count
(
name
)
!=
0
;
}
// Get an engine called `name`.
TensorRTEngine
*
Get
(
const
std
::
string
&
name
)
const
{
return
engines_
.
at
(
name
).
get
();
}
// Create or get an engine called `name`
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
const
std
::
string
&
name
,
int
gpu_device
=
0
)
{
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
,
gpu_device
);
engines_
[
name
].
reset
(
p
);
return
p
;
}
void
DeleteALl
()
{
for
(
auto
&
item
:
engines_
)
{
item
.
second
.
reset
(
nullptr
);
}
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
TensorRTEngine
>>
engines_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
0 → 100644
浏览文件 @
4e3522e5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
// set the batch size before constructing the thread to execute engine
int
TRTInt8Calibrator
::
getBatchSize
()
const
{
return
batch_size_
;
}
TRTInt8Calibrator
::
TRTInt8Calibrator
(
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
buffers
,
int
batch_size
,
std
::
string
engine_name
,
const
platform
::
Place
place
)
:
batch_size_
(
batch_size
),
calib_running_
(
true
),
data_is_set_
(
false
),
done_
(
false
),
engine_name_
(
engine_name
)
{
int
i
=
0
;
VLOG
(
4
)
<<
"Init a new calibrator: "
<<
engine_name_
;
for
(
const
auto
it
:
buffers
)
{
framework
::
Tensor
temp_tensor
;
std
::
string
input_name
=
it
.
first
;
int
data_size
=
it
.
second
;
int
num_ele
=
data_size
/
sizeof
(
int16_t
);
framework
::
DDim
data_shape
=
framework
::
make_ddim
({
num_ele
});
temp_tensor
.
Resize
(
data_shape
);
data_tensors_
.
push_back
(
temp_tensor
);
data_buffers_
[
input_name
]
=
std
::
pair
<
void
*
,
size_t
>
(
static_cast
<
void
*>
(
temp_tensor
.
mutable_data
<
int16_t
>
(
place
)),
num_ele
);
i
+=
1
;
}
}
TRTInt8Calibrator
::
TRTInt8Calibrator
(
const
std
::
string
&
calib_data
)
:
batch_size_
(
0
),
calib_running_
(
false
),
data_is_set_
(
false
),
done_
(
true
),
calibration_table_
(
calib_data
)
{}
void
TRTInt8Calibrator
::
waitAndSetDone
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
while
((
calib_running_
||
data_is_set_
)
&&
!
done_
)
cond_
.
wait
(
lk
);
if
(
!
done_
)
{
done_
=
true
;
cond_
.
notify_all
();
}
}
bool
TRTInt8Calibrator
::
setBatch
(
const
std
::
unordered_map
<
std
::
string
,
void
*>&
data
)
{
VLOG
(
3
)
<<
"set batch: "
<<
engine_name_
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
while
((
calib_running_
||
data_is_set_
)
&&
(
!
done_
))
cond_
.
wait
(
lk
);
if
(
done_
)
return
false
;
// Sets the batch.
for
(
const
auto
it
:
data
)
{
auto
dataptr
=
data_buffers_
.
find
(
it
.
first
);
if
(
dataptr
==
data_buffers_
.
end
())
{
LOG
(
FATAL
)
<<
"FATAL "
<<
engine_name_
<<
" input name '"
<<
it
.
first
<<
"' does not match with the buffer names"
;
}
const
auto
&
d
=
dataptr
->
second
;
auto
status
=
cudaMemcpy
(
d
.
first
,
it
.
second
,
d
.
second
,
cudaMemcpyDeviceToDevice
);
if
(
status
!=
cudaSuccess
)
{
LOG
(
FATAL
)
<<
"cudaMemcpy "
<<
engine_name_
<<
" for '"
<<
it
.
first
<<
"' failed with "
<<
status
;
}
}
data_is_set_
=
true
;
cond_
.
notify_all
();
return
true
;
}
bool
TRTInt8Calibrator
::
getBatch
(
void
**
bindings
,
const
char
**
names
,
int
num_bindings
)
{
VLOG
(
4
)
<<
"get batch: "
<<
engine_name_
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
calib_running_
=
false
;
cond_
.
notify_all
();
while
(
!
data_is_set_
&&
!
done_
)
cond_
.
wait
(
lk
);
if
(
done_
)
return
false
;
// Gets the batch
for
(
int
i
=
0
;
i
<
num_bindings
;
i
++
)
{
auto
it
=
data_buffers_
.
find
(
names
[
i
]);
if
(
it
==
data_buffers_
.
end
())
{
LOG
(
FATAL
)
<<
"Calibration engine asked for unknown tensor name '"
<<
names
[
i
]
<<
"' at position "
<<
i
;
}
bindings
[
i
]
=
it
->
second
.
first
;
}
data_is_set_
=
false
;
calib_running_
=
true
;
VLOG
(
4
)
<<
"get batch done: "
<<
engine_name_
;
return
true
;
}
void
TRTInt8Calibrator
::
setDone
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
done_
=
true
;
cond_
.
notify_all
();
}
const
void
*
TRTInt8Calibrator
::
readCalibrationCache
(
std
::
size_t
&
length
)
{
if
(
calibration_table_
.
empty
())
return
nullptr
;
length
=
calibration_table_
.
size
();
return
calibration_table_
.
data
();
}
void
TRTInt8Calibrator
::
writeCalibrationCache
(
const
void
*
ptr
,
std
::
size_t
length
)
{
calibration_table_
=
std
::
string
((
const
char
*
)
ptr
,
length
);
VLOG
(
4
)
<<
"Got calibration data for "
<<
engine_name_
<<
" "
<<
ptr
<<
" length="
<<
length
;
}
TRTInt8Calibrator
::~
TRTInt8Calibrator
()
{
VLOG
(
4
)
<<
"Destroying calibrator for "
<<
engine_name_
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
0 → 100644
浏览文件 @
4e3522e5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime_api.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TensorRTEngine
;
struct
TRTInt8Calibrator
:
public
nvinfer1
::
IInt8EntropyCalibrator
{
public:
TRTInt8Calibrator
(
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
buffers
,
int
batch_size
,
std
::
string
engine_name
,
const
platform
::
Place
place
);
explicit
TRTInt8Calibrator
(
const
std
::
string
&
calibration_data
);
~
TRTInt8Calibrator
();
int
getBatchSize
()
const
override
;
bool
getBatch
(
void
*
bindings
[],
const
char
*
names
[],
int
num_bindings
)
override
;
bool
setBatch
(
const
std
::
unordered_map
<
std
::
string
,
void
*>&
data
);
void
setDone
();
void
waitAndSetDone
();
const
void
*
readCalibrationCache
(
std
::
size_t
&
length
)
override
;
void
writeCalibrationCache
(
const
void
*
ptr
,
std
::
size_t
length
)
override
;
const
std
::
string
&
getCalibrationTableAsString
()
{
return
calibration_table_
;
}
private:
const
int
batch_size_
;
bool
calib_running_
;
bool
data_is_set_
;
bool
done_
;
std
::
mutex
mut_
;
std
::
condition_variable
cond_
;
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
void
*
,
size_t
>>
data_buffers_
;
std
::
vector
<
framework
::
Tensor
>
data_tensors_
;
std
::
string
engine_name_
;
std
::
string
calibration_table_
;
};
class
TRTCalibratorRes
{
public:
TRTCalibratorRes
()
{}
std
::
unique_ptr
<
TRTInt8Calibrator
>
calib_
;
std
::
unique_ptr
<
std
::
thread
>
thr_
;
std
::
unique_ptr
<
TensorRTEngine
>
engine_
;
};
/*
* Manager to control the TensorRT Int8 calibration creation and deltetion.
*/
class
TRTCalibratorResManager
{
public:
bool
Has
()
const
{
return
res_
.
size
()
>
0
;
}
bool
Has
(
const
std
::
string
&
name
)
const
{
if
(
res_
.
count
(
name
)
==
0
)
return
false
;
return
res_
.
at
(
name
).
get
()
!=
nullptr
;
}
// Get Int8Calibrator via name
TRTCalibratorRes
*
Get
(
const
std
::
string
&
name
)
const
{
return
res_
.
at
(
name
).
get
();
}
// Look up or create a calibrator.
TRTCalibratorRes
*
LookupOrCreate
(
const
std
::
string
&
engine_name
)
{
if
(
res_
.
count
(
engine_name
)
==
0
)
{
auto
*
p
=
new
TRTCalibratorRes
();
res_
[
engine_name
].
reset
(
p
);
}
return
res_
.
at
(
engine_name
).
get
();
}
// Create an Int8Calibrator
TRTCalibratorRes
*
Create
(
const
std
::
string
&
engine_name
)
{
auto
*
p
=
new
TRTCalibratorRes
();
res_
[
engine_name
].
reset
(
p
);
return
p
;
}
void
DeleteALL
()
{
for
(
auto
&
item
:
res_
)
{
item
.
second
.
reset
(
nullptr
);
}
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
TRTCalibratorRes
>>
res_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
浏览文件 @
4e3522e5
...
...
@@ -29,8 +29,15 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Xs"
,
"A list of inputs."
).
AsDuplicable
();
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph."
);
AddAttr
<
std
::
string
>
(
"calibration_data"
,
"the calibration data for int8"
);
AddAttr
<
std
::
string
>
(
"engine_key"
,
"The engine_key here is used to distinguish different TRT Engines"
);
AddAttr
<
int
>
(
"max_batch_size"
,
"the maximum batch size."
);
AddAttr
<
int
>
(
"workspace_size"
,
"the workspace size."
);
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
"the trt block"
);
AddAttr
<
std
::
string
>
(
"precision_mode"
,
"the precision mode: 'FP32', 'INT8' "
);
AddComment
(
"TensorRT engine operator."
);
}
};
...
...
@@ -47,6 +54,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
tensorrt_engine
,
ops
::
TensorRTEngineOp
,
ops
::
TensorRTEngineOpMaker
);
ops
::
TensorRTEngineOpMaker
,
ops
::
TensorRTEngineOpMaker
);
#endif // PADDLE_WITH_CUDA
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
4e3522e5
...
...
@@ -17,8 +17,10 @@
#ifdef PADDLE_WITH_CUDA
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
...
...
@@ -62,6 +64,9 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TensorRTEngine
;
using
inference
::
tensorrt
::
TRTInt8Calibrator
;
using
inference
::
tensorrt
::
TRTCalibratorRes
;
using
inference
::
tensorrt
::
TRTCalibratorResManager
;
class
TensorRTEngineOp
:
public
framework
::
OperatorBase
{
private:
...
...
@@ -70,6 +75,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable
std
::
unique_ptr
<
TensorRTEngine
>
trt_engine_
;
int
max_batch_size_
;
int
workspace_size_
;
std
::
unique_ptr
<
TRTInt8Calibrator
>
calibrator_
;
std
::
string
precision_mode_
;
std
::
string
calibration_data_
;
std
::
string
engine_key_
;
bool
calibration_mode_
;
public:
TensorRTEngineOp
(
const
std
::
string
&
type
,
...
...
@@ -80,26 +90,95 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_
=
Inputs
(
"Xs"
);
max_batch_size_
=
Attr
<
int
>
(
"max_batch_size"
);
workspace_size_
=
Attr
<
int
>
(
"workspace_size"
);
precision_mode_
=
Attr
<
std
::
string
>
(
"precision_mode"
);
calibration_data_
=
Attr
<
std
::
string
>
(
"calibration_data"
);
engine_key_
=
Attr
<
std
::
string
>
(
"engine_key"
);
auto
params
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
for
(
const
auto
&
param
:
params
)
{
param_names_
.
insert
(
param
);
}
calibration_mode_
=
(
precision_mode_
==
"INT8"
&&
calibration_data_
.
size
()
==
0
);
if
(
precision_mode_
==
"INT8"
&&
calibration_data_
.
size
())
{
calibrator_
.
reset
(
new
TRTInt8Calibrator
(
calibration_data_
));
}
}
protected:
void
RunNative
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
framework
::
Executor
executor
(
dev_place
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
"sub_block"
);
auto
*
program
=
block
->
Program
();
auto
*
scope_ptr
=
const_cast
<
framework
::
Scope
*>
(
&
scope
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
executor
.
RunPreparedContext
(
ctx
.
get
(),
scope_ptr
,
false
,
true
,
true
);
}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
if
(
calibration_mode_
==
true
)
{
RunCalibration
(
scope
,
dev_place
);
return
;
}
RunTrt
(
scope
,
dev_place
);
}
void
RunCalibration
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
// Create calibrator here.
LOG
(
INFO
)
<<
"Running calibration trt int8 ..."
;
int
runtime_batch
=
1
;
if
(
!
Singleton
<
TRTCalibratorResManager
>::
Global
().
Has
(
engine_key_
))
{
TRTCalibratorRes
*
calib_res
=
Singleton
<
TRTCalibratorResManager
>::
Global
().
Create
(
engine_key_
);
std
::
unordered_map
<
std
::
string
,
size_t
>
calib_buffers
;
for
(
auto
&
x
:
input_names_
)
{
if
(
param_names_
.
count
(
x
))
continue
;
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
calib_buffers
[
x
]
=
t
.
memory_size
();
auto
t_shape
=
framework
::
vectorize
(
t
.
dims
());
runtime_batch
=
t_shape
[
0
];
}
calib_res
->
calib_
.
reset
(
new
TRTInt8Calibrator
(
calib_buffers
,
runtime_batch
,
engine_key_
,
dev_place
));
calib_res
->
thr_
.
reset
(
new
std
::
thread
([
&
]()
{
calib_res
->
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
nullptr
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
precision_mode_
,
calib_res
->
calib_
.
get
()));
VLOG
(
3
)
<<
"start the calib trt engine thread"
;
Prepare
(
scope
,
dev_place
,
calib_res
->
engine_
.
get
());
}));
}
TRTInt8Calibrator
*
temp_calibrator
=
Singleton
<
TRTCalibratorResManager
>::
Global
()
.
Get
(
engine_key_
)
->
calib_
.
get
();
std
::
unordered_map
<
std
::
string
,
void
*>
calib_data
;
for
(
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
calib_data
.
emplace
(
x
,
t
.
data
<
void
>
());
}
temp_calibrator
->
setBatch
(
calib_data
);
RunNative
(
scope
,
dev_place
);
}
void
RunTrt
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
int
runtime_batch
=
1
;
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
nullptr
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
));
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
nullptr
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
precision_mode_
,
calibrator_
.
get
()));
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
}
...
...
@@ -168,7 +247,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
void
Prepare
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
TensorRTEngine
*
engine
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
LOG
(
INFO
)
<<
"Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time."
;
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
Attr
<
std
::
string
>
(
"subgraph"
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录