Commit 3d63aa0a (unverified)
Authored by Zhaolong Xing on Mar 07, 2019; committed via GitHub on Mar 07, 2019.

Merge pull request #15729 from NHZlX/add_static_model_load_for_trt

Four points for enhancing Paddle-TRT

Parents: b0e3c024, a9ed4277
Showing 47 changed files with 1006 additions and 535 deletions (+1006, -535):

Dockerfile (+3, -2)
paddle/fluid/framework/ir/fuse_pass_base.h (+5, -0)
paddle/fluid/inference/analysis/argument.h (+6, -0)
paddle/fluid/inference/analysis/helper.h (+31, -0)
paddle/fluid/inference/analysis/ir_pass_manager.cc (+3, -0)
paddle/fluid/inference/analysis/ir_pass_manager.h (+3, -0)
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc (+193, -74)
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h (+9, -3)
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc (+11, -0)
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h (+1, -0)
paddle/fluid/inference/api/analysis_config.cc (+3, -1)
paddle/fluid/inference/api/analysis_predictor.cc (+32, -0)
paddle/fluid/inference/api/analysis_predictor.h (+9, -0)
paddle/fluid/inference/api/details/zero_copy_tensor.cc (+58, -2)
paddle/fluid/inference/api/details/zero_copy_tensor_dummy.cc (+1, -1)
paddle/fluid/inference/api/helper.h (+5, -0)
paddle/fluid/inference/api/paddle_analysis_config.h (+3, -1)
paddle/fluid/inference/api/paddle_api.h (+21, -1)
paddle/fluid/inference/engine.h (+0, -5)
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc (+2, -19)
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc (+1, -2)
paddle/fluid/inference/tensorrt/convert/fc_op.cc (+2, -2)
paddle/fluid/inference/tensorrt/convert/op_converter.h (+62, -0)
paddle/fluid/inference/tensorrt/convert/prelu_op.cc (+8, -11)
paddle/fluid/inference/tensorrt/convert/ut_helper.h (+51, -34)
paddle/fluid/inference/tensorrt/engine.cc (+12, -131)
paddle/fluid/inference/tensorrt/engine.h (+39, -52)
paddle/fluid/inference/tensorrt/helper.h (+29, -0)
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt (+2, -1)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu (+7, -0)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h (+9, -5)
paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu (+9, -2)
paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h (+13, -7)
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu (+14, -1)
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h (+30, -13)
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu (+6, -0)
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h (+6, -3)
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h (+9, -1)
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc (+48, -0)
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h (+78, -0)
paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h (+8, -1)
paddle/fluid/inference/tensorrt/test_engine.cc (+92, -42)
paddle/fluid/inference/tests/api/trt_models_tester.cc (+2, -1)
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc (+3, -0)
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h (+63, -116)
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc (+2, -0)
paddle/fluid/pybind/inference_api.cc (+2, -1)
Dockerfile
@@ -75,8 +75,9 @@ RUN curl -s -q https://glide.sh/get | sh
 # and its size is only one-third of the official one.
 # 2. Manually add ~IPluginFactory() in IPluginFactory class of NvInfer.h, otherwise, it couldn't work in paddle.
 # See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
-RUN wget -qO- http://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
-    tar -xz -C /usr/local && \
+RUN wget -q https://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
+    tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
     cp -rf /usr/local/TensorRT/include /usr && \
     cp -rf /usr/local/TensorRT/lib /usr
paddle/fluid/framework/ir/fuse_pass_base.h
@@ -14,6 +14,7 @@
 #pragma once
+#include <string>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/scope.h"
@@ -24,6 +25,10 @@ namespace ir {
 static const char kParamScopeAttr[] = "__param_scope__";
 static const char kFuseStatisAttr[] = "__fuse_statis__";
+// When we use trt or other third_party lib, the parameters are managed by
+// the lib, but not the fluid. So we need to record them to avoid duplicate
+// allocation.
+static const char kRepetitiveParamAttr[] = "__repetitive_param__";
 
 enum FuseOptions {
   DO_NOT_FUSE,  // fusing will not be done
paddle/fluid/inference/analysis/argument.h
@@ -23,8 +23,12 @@
 #pragma once
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
@@ -133,6 +137,8 @@ struct Argument {
 DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
 DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                     AnalysisConfig::Precision);
+DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
+                    bool);
 
 // Memory optimized related.
 DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
paddle/fluid/inference/analysis/helper.h
@@ -17,10 +17,12 @@ limitations under the License. */
 #include <sys/stat.h>
 #include <cstdio>
 #include <fstream>
+#include <memory>
 #include <set>
 #include <string>
 #include <typeindex>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/framework.pb.h"
@@ -217,6 +219,35 @@ static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
   return "";
 }
 
+static std::string GetTrtEngineSerializedPath(const std::string &model_root,
+                                              const std::string &engine_key) {
+  return model_root + "/trt_serialized_" + engine_key;
+}
+
+static std::string GetTrtEngineSerializedData(
+    const std::string &model_opt_cache_dir, const std::string &engine_key) {
+  std::string trt_serialized_path =
+      GetTrtEngineSerializedPath(model_opt_cache_dir, engine_key);
+  if (FileExists(trt_serialized_path)) {
+    VLOG(3) << "Trt serialized file: " << trt_serialized_path
+            << "is found here";
+    std::ifstream infile(trt_serialized_path, std::ios::in);
+    std::stringstream buffer;
+    buffer << infile.rdbuf();
+    std::string trt_engine_serialized_data(buffer.str());
+    return trt_engine_serialized_data;
+  }
+  return "";
+}
+
+static void SaveTrtEngineSerializedDataToFile(
+    const std::string &trt_serialized_path,
+    const std::string &engine_serialized_data) {
+  std::ofstream outfile(trt_serialized_path);
+  outfile << engine_serialized_data;
+  outfile.close();
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
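Taken together, these helpers implement a simple file-based cache for serialized TensorRT engines, keyed by an engine key. A minimal sketch of the intended round trip, using an invented cache directory and key rather than values produced by Paddle itself:

// Sketch only: "./_opt_cache" and the key are illustrative values.
std::string dir = "./_opt_cache";
std::string key = "1234567890";  // in the pass this is a hash of the subgraph's I/O names

std::string cached = GetTrtEngineSerializedData(dir, key);
if (cached.empty()) {
  // Build the TensorRT engine, serialize it, then persist the bytes so the
  // next process start can skip the (slow) build step.
  std::string serialized_bytes = /* trt_engine->Serialize() contents */ "";
  SaveTrtEngineSerializedDataToFile(GetTrtEngineSerializedPath(dir, key),
                                    serialized_bytes);
} else {
  // Deserialize `cached` into a TensorRT engine and reuse it directly.
}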
paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -81,6 +81,9 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set(
           "model_opt_cache_dir",
           new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
+      pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
+      pass->Set("use_static_engine",
+                new bool(argument->tensorrt_use_static_engine()));
     }
 
     pre_pass = pass_name;
paddle/fluid/inference/analysis/ir_pass_manager.h
@@ -22,7 +22,10 @@
 #pragma once
 
 #include <memory>
 #include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -14,13 +14,13 @@
 #include <algorithm>
 #include <set>
 #include <string>
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
 #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"
 #include "paddle/fluid/string/pretty_log.h"
@@ -33,8 +33,15 @@ using framework::ir::Node;
 std::vector<std::string> ExtractParameters(
     const std::unordered_set<Node *> &nodes);
 
+void RenameAndGetOutputs(
+    const std::vector<framework::ir::Node *> &subgraph_nodes,
+    framework::BlockDesc *block_desc,
+    const std::set<std::string> &input_names_with_id,
+    std::set<std::string> *output_names_with_id,
+    std::set<std::string> *output_names,
+    std::unordered_map<std::string, std::string> *output_name_map);
+
 std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
     std::unique_ptr<framework::ir::Graph> graph) const {
   framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph.get());
@@ -47,9 +54,16 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
       Get<int>("min_subgraph_size") /*min subgraph size*/);
   fuser();
 
+  std::vector<std::string> graph_param_names =
+      ExtractParameters(graph->Nodes());
+  // those parameter already exist in trt, and should not have another copy in
+  // fluid.
+  std::vector<std::string> repetitive_params;
+
   for (auto *node : graph->Nodes()) {
     if (node->IsOp() && !Agent(node).subgraph()->empty()) {
-      CreateTensorRTOp(node, graph.get());
+      CreateTensorRTOp(node, graph.get(), graph_param_names,
+                       &repetitive_params);
 
       std::unordered_set<const Node *> nodes2remove(
           Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
@@ -64,12 +78,15 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
     }
   }
 
   framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
+  graph->Set(framework::ir::kRepetitiveParamAttr,
+             new std::vector<std::string>(repetitive_params));
 
   return graph;
 }
 
 std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
-                              const std::set<std::string> &engine_outputs) {
+                              const std::set<std::string> &engine_outputs,
+                              const std::string &predictor_id) {
   std::string engine_hash_key = "";
   for (auto name : engine_inputs) {
     engine_hash_key += name;
@@ -77,12 +94,15 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
   for (auto name : engine_outputs) {
     engine_hash_key += name;
   }
+  engine_hash_key += predictor_id;
   auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
   return engine_key;
 }
 
 void TensorRtSubgraphPass::CreateTensorRTOp(
-    framework::ir::Node *node, Graph *graph) const {
+    framework::ir::Node *node, Graph *graph,
+    const std::vector<std::string> &graph_params,
+    std::vector<std::string> *repetitive_params) const {
   auto *op_desc = node->Op();
   auto &subgraph = *Agent(node).subgraph();
   PADDLE_ENFORCE(!subgraph.empty());
@@ -116,12 +136,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // is unique.
   std::set<std::string> input_names;
   std::set<std::string> input_names_with_id;
+  std::vector<std::string> params;
+
+  // The node->inputs containes input tensors and parameters.
   for (auto *x : node->inputs) {
     input_names.insert(x->Name());
     input_names_with_id.insert(x->Name() + std::to_string(x->id()));
+    if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
+      params.push_back(x->Name());
+    }
   }
-  op_desc->SetInput(
-      "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
 
   std::set<std::string> output_names;
   std::set<std::string> output_names_with_id;
@@ -130,11 +154,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
     output_names_with_id.insert(x->Name() + std::to_string(x->id()));
   }
 
-  op_desc->SetOutput(
-      "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
-  op_desc->SetType("tensorrt_engine");
-
   std::unordered_map<std::string, std::string> output_name_map;
+  auto &subgraph_nodes = *Agent(node).subgraph();
 
   // The following procedure is used to rename all the intermediate
   // variables and the output variables of the subgraph.
@@ -148,61 +169,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // input of a OP, but also the output of a Op, there will be problems.
   // So we have to rename the variable in the subgraph to make sure
   // it is either an OP's input or an OP's output.
-  auto &subgraph_nodes = *Agent(node).subgraph();
-  for (size_t index = 0; index < block_desc.OpSize(); ++index) {
-    framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
-    auto correspond_node = subgraph_nodes[index];
-    PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
-
-    std::unordered_map<std::string, size_t> var2id;
-    for (auto *in_var : correspond_node->inputs) {
-      var2id[in_var->Name()] = in_var->id();
-    }
-    // rename for the input variables of op inside subgraph
-    for (int i = 0; i < op->inputs_size(); i++) {
-      // one input
-      auto *in_var = op->mutable_inputs(i);
-      std::vector<std::string> replaced_names;
-      for (int k = 0; k < in_var->arguments_size(); k++) {  // all the arguments
-        std::string arg_value = in_var->arguments(k);
-        std::string arg_value_with_id =
-            arg_value + std::to_string(var2id[arg_value]);
-        if (input_names_with_id.count(arg_value_with_id)) {
-          replaced_names.push_back(arg_value);
-        } else {
-          replaced_names.push_back(arg_value_with_id);
-        }
-      }
-      in_var->clear_arguments();
-      for (size_t k = 0; k < replaced_names.size(); k++) {
-        in_var->add_arguments(replaced_names[k]);
-      }
-    }
-    var2id.clear();
-    for (auto out_var : correspond_node->outputs) {
-      var2id[out_var->Name()] = out_var->id();
-    }
-
-    // rename for the output variables of op inside subgraph
-    for (int i = 0; i < op->outputs_size(); i++) {
-      framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
-      std::vector<std::string> replaced_names;
-      for (int k = 0; k < out_var->arguments_size(); k++) {
-        std::string arg_value = out_var->arguments(k);
-        std::string arg_value_with_id =
-            arg_value + std::to_string(var2id[arg_value]);
-        if (output_names_with_id.count(arg_value_with_id)) {
-          output_name_map[arg_value] = arg_value_with_id;
-        }
-        replaced_names.push_back(arg_value_with_id);
-      }
-      out_var->clear_arguments();
-      for (size_t k = 0; k < replaced_names.size(); k++) {
-        out_var->add_arguments(replaced_names[k]);
-      }
-    }
-  }
+  RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id,
+                      &output_names_with_id, &output_names, &output_name_map);
 
   // When tensorrt engine runs at the end of the operation,
   // output_mapping help us copy the data from the renamed ITensor
@@ -212,6 +180,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
     PADDLE_ENFORCE(output_name_map.count(name) != 0);
     output_mapping.push_back(output_name_map[name]);
   }
+  PADDLE_ENFORCE(!output_mapping.empty());
 
   auto *vars = block_desc.Proto()->mutable_vars();
   for (framework::ir::Node *node : graph->Nodes()) {
@@ -222,26 +191,83 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
                  "the block has no var-desc");
-  PADDLE_ENFORCE(!output_mapping.empty());
 
   // Set attrs
+  op_desc->SetType("tensorrt_engine");
+  op_desc->SetInput(
+      "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
+  op_desc->SetOutput(
+      "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
+
   op_desc->SetBlockAttr("sub_block", new_block);
   SetAttr(op_desc->Proto(), "subgraph",
           block_desc.Proto()->SerializeAsString());
-  // Set attrs
   SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
   SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
-  SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
   SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
+  SetAttr(op_desc->Proto(), "parameters", params);
 
   auto enable_int8 = Get<bool>("enable_int8");
-  auto engine_key =
-      GenerateEngineKey(input_names_with_id, output_names_with_id);
+  auto engine_key = GenerateEngineKey(input_names_with_id,
+                                      output_names_with_id, std::to_string(0));
 
   // Get "" when there is no cached calibration table data.
   std::string calibration_data = GetTrtCalibTableData(
       Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
   SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
+
   SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
   SetAttr(op_desc->Proto(), "engine_key", engine_key);
+  SetAttr(op_desc->Proto(), "engine_serialized_data", std::string(""));
+
+  std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
+  if (enable_int8 && calibration_data.size() != 0) {
+    calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
+  }
+
+  bool use_static_engine = Get<bool>("use_static_engine");
+  // When in int8 mode and calibration_mode, the program just produce the
+  // calibration table data.
+  bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
+  if (!calibration_mode && use_static_engine) {
+    std::copy(params.begin(), params.end(),
+              std::back_inserter(*repetitive_params));
+    std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
+        Get<std::string>("model_opt_cache_dir"), engine_key);
+
+    if (trt_engine_serialized_data.empty()) {
+      LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
+                   "kernel etc). This process may cost a lot of time.";
+      std::unique_ptr<tensorrt::TensorRTEngine> trt_engine(
+          new tensorrt::TensorRTEngine(
+              Get<int>("max_batch_size"), Get<int>("workspace_size"),
+              enable_int8, calibrator.get(), Get<int>("gpu_device_id")));
+      auto *scope = param_scope();
+      framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
+      std::unordered_set<std::string> param_set(params.begin(), params.end());
+      inference::Singleton<inference::tensorrt::OpConverter>::Global()
+          .ConvertBlockToTRTEngine(
+              &block_desc_temp, *scope,
+              std::vector<std::string>(input_names.begin(), input_names.end()),
+              param_set, output_mapping, trt_engine.get());
+      nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
+      trt_engine_serialized_data =
+          std::string((const char *)serialized_engine_data->data(),
+                      serialized_engine_data->size());
+      SaveTrtEngineSerializedDataToFile(
+          GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
+                                     engine_key),
+          trt_engine_serialized_data);
+    } else {
+      LOG(INFO) << "Load TRT Optimized Info from "
+                << GetTrtEngineSerializedPath(
+                       Get<std::string>("model_opt_cache_dir"), engine_key);
+    }
+    SetAttr(op_desc->Proto(), "engine_serialized_data",
+            trt_engine_serialized_data);
+  }
+}
 
 std::vector<std::string> ExtractParameters(
@@ -253,7 +279,7 @@ std::vector<std::string> ExtractParameters(
   for (const auto &node : nodes) {
     if (!node->IsOp()) continue;
     std::string op_type = node->Op()->Type();
-    if (op_type == "feed") {
+    if (op_type == "feed" || op_type == "fetch") {
       std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
       std::copy(output_names.begin(), output_names.end(),
                 std::back_inserter(feed_outputs));
@@ -272,6 +298,99 @@ std::vector<std::string> ExtractParameters(
   return parameters;
 }
 
+void RenameAndGetOutputs(
+    const std::vector<framework::ir::Node *> &subgraph_nodes,
+    framework::BlockDesc *block_desc,
+    const std::set<std::string> &input_names_with_id,
+    std::set<std::string> *output_names_with_id,
+    std::set<std::string> *output_names,
+    std::unordered_map<std::string, std::string> *output_name_map) {
+  //// In the normal case, the paddle-trt exists bug when runing the googlenet.
+  // When there are more than two convolutions of 1 * 1 with the same input, the
+  // paddle-tensorrt will do the merging optimization, which fuse those conv
+  // into one conv, and then trigger bug. So, We should use strategy to avoid
+  // this optimization for the time being. This bug will be fixed in the future.
+  std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
+      same_hierarchy_conv2d_num_map;
+
+  for (size_t index = 0; index < block_desc->OpSize(); ++index) {
+    framework::proto::OpDesc *op = block_desc->Op(index)->Proto();
+    framework::OpDesc op_desc(*op, nullptr);
+    auto correspond_node = subgraph_nodes[index];
+    PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
+
+    std::unordered_map<std::string, size_t> var2id;
+    std::unordered_map<std::string, framework::ir::Node *> in_vars;
+    for (auto *in_var : correspond_node->inputs) {
+      var2id[in_var->Name()] = in_var->id();
+      in_vars[in_var->Name()] = in_var;
+    }
+    // rename for the input variables of op inside subgraph
+    for (int i = 0; i < op->inputs_size(); i++) {
+      // one input
+      auto *in_var = op->mutable_inputs(i);
+      std::vector<std::string> replaced_names;
+      for (int k = 0; k < in_var->arguments_size(); k++) {  // all the arguments
+        std::string arg_value = in_var->arguments(k);
+        std::string arg_value_with_id =
+            arg_value + std::to_string(var2id[arg_value]);
+        if (input_names_with_id.count(arg_value_with_id)) {
+          replaced_names.push_back(arg_value);
+        } else {
+          replaced_names.push_back(arg_value_with_id);
+        }
+      }
+      in_var->clear_arguments();
+      for (size_t k = 0; k < replaced_names.size(); k++) {
+        in_var->add_arguments(replaced_names[k]);
+      }
+    }
+    var2id.clear();
+    for (auto out_var : correspond_node->outputs) {
+      var2id[out_var->Name()] = out_var->id();
+    }
+
+    if (op_desc.Type() == "conv2d") {
+      auto input_var_name = op_desc.Input("Input").front();
+      auto filter_var_name = op_desc.Input("Filter").front();
+      auto out_var_name = op_desc.Output("Output").front();
+      auto filter_shape = in_vars[filter_var_name]->Var()->GetShape();
+      const std::vector<int> strides =
+          boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
+      const std::vector<int> paddings =
+          boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
+      if (same_hierarchy_conv2d_num_map[input_var_name] > 0) {
+        (*output_names_with_id)
+            .insert(out_var_name + std::to_string(var2id[out_var_name]));
+        (*output_names).insert(out_var_name);
+      } else if (filter_shape[2] == 1 && filter_shape[3] == 1 &&
+                 strides[0] == 1 && strides[1] == 1 && paddings[0] == 0 &&
+                 paddings[1] == 0) {
+        same_hierarchy_conv2d_num_map[input_var_name] += 1;
+      }
+    }
+
+    // rename for the output variables of op inside subgraph
+    for (int i = 0; i < op->outputs_size(); i++) {
+      framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
+      std::vector<std::string> replaced_names;
+      for (int k = 0; k < out_var->arguments_size(); k++) {
+        std::string arg_value = out_var->arguments(k);
+        std::string arg_value_with_id =
+            arg_value + std::to_string(var2id[arg_value]);
+        if (output_names_with_id->count(arg_value_with_id)) {
+          (*output_name_map)[arg_value] = arg_value_with_id;
+        }
+        replaced_names.push_back(arg_value_with_id);
+      }
+      out_var->clear_arguments();
+      for (size_t k = 0; k < replaced_names.size(); k++) {
+        out_var->add_arguments(replaced_names[k]);
+      }
+    }
+  }
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
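The engine key that names a cache entry is nothing more than a std::hash over the concatenated input names, output names, and predictor id. A standalone sketch of the same computation with invented names:

#include <functional>
#include <set>
#include <string>

// Mirrors GenerateEngineKey above; the example arguments are invented.
std::string MakeEngineKey(const std::set<std::string> &inputs,
                          const std::set<std::string> &outputs,
                          const std::string &predictor_id) {
  std::string hash_key;
  for (const auto &name : inputs) hash_key += name;
  for (const auto &name : outputs) hash_key += name;
  hash_key += predictor_id;
  return std::to_string(std::hash<std::string>()(hash_key));
}

// e.g. MakeEngineKey({"conv_in0", "fc_w0"}, {"prob0"}, "0") yields a stable
// numeric string that becomes part of the "trt_serialized_<key>" file name.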
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h
@@ -13,7 +13,12 @@
 // limitations under the License.
 
 #pragma once
-#include <paddle/fluid/framework/ir/fuse_pass_base.h>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
@@ -26,8 +31,9 @@ class TensorRtSubgraphPass : public framework::ir::FusePassBase {
       std::unique_ptr<framework::ir::Graph> graph) const override;
 
  private:
-  void CreateTensorRTOp(framework::ir::Node *x,
-                        framework::ir::Graph *graph) const;
+  void CreateTensorRTOp(framework::ir::Node *x, framework::ir::Graph *graph,
+                        const std::vector<std::string> &graph_params,
+                        std::vector<std::string> *repetitive_params) const;
   void CleanIntermediateOutputs(framework::ir::Node *node);
 };
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -31,6 +31,13 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
   // The parameters are on the cpu, therefore, synchronization is not necessary.
   if (!argument->use_gpu()) return;
 
+  auto &graph = argument->main_graph();
+  std::vector<std::string> repetitive_params;
+
+  if (graph.Has(framework::ir::kRepetitiveParamAttr))
+    repetitive_params = graph.Get<std::vector<std::string>>(
+        framework::ir::kRepetitiveParamAttr);
+
   LOG(INFO) << "Sync params from CPU to GPU";
 
   PADDLE_ENFORCE(argument->gpu_device_id_valid());
@@ -43,6 +50,10 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
   // Because there exists the case that new parameter variables are not added to
   // the program in the analysis pass.
   for (auto &var_name : all_vars) {
+    if (std::count(repetitive_params.begin(), repetitive_params.end(),
+                   var_name)) {
+      continue;
+    }
     auto *var = scope->FindLocalVar(var_name);
     PADDLE_ENFORCE(var != nullptr);
     if (var->IsType<framework::LoDTensor>() ||
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/analysis/analysis_pass.h"
 #include "paddle/fluid/platform/place.h"
paddle/fluid/inference/api/analysis_config.cc
@@ -103,6 +103,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
+  CP_MEMBER(trt_use_static_engine_);
   // MKLDNN related.
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
@@ -144,7 +145,7 @@ void AnalysisConfig::EnableMKLDNN() {
 
 void AnalysisConfig::EnableTensorRtEngine(
     int workspace_size, int max_batch_size, int min_subgraph_size,
-    AnalysisConfig::Precision precision_mode) {
+    AnalysisConfig::Precision precision_mode, bool use_static) {
 #ifdef PADDLE_WITH_CUDA
   if (!use_gpu()) {
     LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
@@ -156,6 +157,7 @@ void AnalysisConfig::EnableTensorRtEngine(
   tensorrt_max_batchsize_ = max_batch_size;
   tensorrt_min_subgraph_size_ = min_subgraph_size;
   tensorrt_precision_mode_ = precision_mode;
+  trt_use_static_engine_ = use_static;
 
   Update();
 #else
paddle/fluid/inference/api/analysis_predictor.cc
@@ -365,6 +365,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
     argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
+    argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
   }
 
   if (config_.use_mkldnn_) {
@@ -438,12 +439,14 @@ void AnalysisPredictor::PrepareFeedFetch() {
       }
       feeds_[idx] = op;
       feed_names_[op->Output("Out")[0]] = idx;
+      idx2feeds_[idx] = op->Output("Out")[0];
     } else if (op->Type() == "fetch") {
       int idx = boost::get<int>(op->GetAttr("col"));
       if (fetches_.size() <= static_cast<size_t>(idx)) {
         fetches_.resize(idx + 1);
       }
       fetches_[idx] = op;
+      idx2fetches_[idx] = op->Input("X")[0];
     }
   }
 }
@@ -456,6 +459,22 @@ void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
   var->GetMutable<framework::FeedFetchList>();
 }
 
+std::vector<std::string> AnalysisPredictor::GetInputNames() {
+  std::vector<std::string> input_names;
+  for (auto &item : idx2feeds_) {
+    input_names.push_back(item.second);
+  }
+  return input_names;
+}
+
+std::vector<std::string> AnalysisPredictor::GetOutputNames() {
+  std::vector<std::string> output_names;
+  for (auto &item : idx2fetches_) {
+    output_names.push_back(item.second);
+  }
+  return output_names;
+}
+
 std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
     const std::string &name) {
   PADDLE_ENFORCE(executor_->scope()->FindVar(name), "no name called %s", name);
@@ -463,6 +482,13 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
       new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
   res->input_or_output_ = true;
   res->SetName(name);
+  if (platform::is_cpu_place(place_)) {
+    res->SetPlace(PaddlePlace::kCPU);
+  } else {
+    auto gpu_place = boost::get<platform::CUDAPlace>(place_);
+    res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
+  }
+
   return res;
 }
 
@@ -473,6 +499,12 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
       new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
   res->input_or_output_ = false;
   res->SetName(name);
+  if (platform::is_cpu_place(place_)) {
+    res->SetPlace(PaddlePlace::kCPU);
+  } else {
+    auto gpu_place = boost::get<platform::CUDAPlace>(place_);
+    res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
+  }
   return res;
 }
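A minimal sketch of how client code might use the new name accessors together with the zero-copy tensors; the model directory and input shape are invented and error handling is omitted:

#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch only; "./mobilenet" and the 1x3x224x224 shape are illustrative.
void QueryInputOutputNames() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet");
  config.EnableUseGpu(100 /*MB*/, 0 /*gpu id*/);
  config.SwitchUseFeedFetchOps(false);  // required for ZeroCopyTensor I/O

  auto predictor = paddle::CreatePaddlePredictor(config);
  std::vector<std::string> in_names = predictor->GetInputNames();    // new
  std::vector<std::string> out_names = predictor->GetOutputNames();  // new
  auto input = predictor->GetInputTensor(in_names[0]);
  input->Reshape({1, 3, 224, 224});
}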
paddle/fluid/inference/api/analysis_predictor.h
@@ -15,12 +15,14 @@
 #pragma once
 #include <algorithm>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/string/printf.h"
 #ifdef PADDLE_WITH_TESTING
@@ -53,6 +55,9 @@ class AnalysisPredictor : public PaddlePredictor {
            std::vector<PaddleTensor> *output_data,
            int batch_size = -1) override;
 
+  std::vector<std::string> GetInputNames();
+  std::vector<std::string> GetOutputNames();
+
   std::unique_ptr<ZeroCopyTensor> GetInputTensor(
       const std::string &name) override;
   std::unique_ptr<ZeroCopyTensor> GetOutputTensor(
@@ -131,7 +136,11 @@ class AnalysisPredictor : public PaddlePredictor {
   std::shared_ptr<framework::ProgramDesc> inference_program_;
   std::vector<framework::OpDesc *> feeds_;
   std::map<std::string, size_t> feed_names_;
+  // Sorted according to the idx.
+  std::map<size_t, std::string> idx2feeds_;
   std::vector<framework::OpDesc *> fetches_;
+  std::map<size_t, std::string> idx2fetches_;
+
   // Memory buffer for feed inputs. The temporary LoDTensor will cause serious
   // concurrency problems, wrong results and memory leak, so cache them.
   std::vector<framework::LoDTensor> feed_tensors_;
paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -73,6 +74,61 @@ T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
   return res;
 }
 
+template <typename T>
+void ZeroCopyTensor::copy_from_cpu(const T *data) {
+  EAGER_GET_TENSOR;
+  PADDLE_ENFORCE_GE(
+      tensor->numel(), 0,
+      "You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
+      "function before copy data from cpu.");
+  size_t ele_size = tensor->numel() * sizeof(T);
+
+  if (place_ == PaddlePlace::kCPU) {
+    auto *t_data = tensor->mutable_data<T>(platform::CPUPlace());
+    std::memcpy(static_cast<void *>(t_data), data, ele_size);
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    platform::CUDAPlace gpu_place(device_);
+    auto *t_data = tensor->mutable_data<T>(gpu_place);
+    auto *dev_ctx =
+        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
+
+    memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
+                 data, ele_size, dev_ctx->stream());
+#else
+    PADDLE_THROW("Not compile with CUDA, should not reach here.");
+#endif
+  }
+}
+
+template <typename T>
+void ZeroCopyTensor::copy_to_cpu(T *data) {
+  EAGER_GET_TENSOR;
+  auto ele_num = tensor->numel();
+  auto *t_data = tensor->data<T>();
+  auto t_place = tensor->place();
+
+  if (platform::is_cpu_place(t_place)) {
+    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto gpu_place = boost::get<platform::CUDAPlace>(t_place);
+    auto *dev_ctx =
+        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
+    memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
+                 t_data, ele_num * sizeof(T), dev_ctx->stream());
+#else
+    PADDLE_THROW("Not compile with CUDA, should not reach here.");
+#endif
+  }
+}
+
+template void ZeroCopyTensor::copy_from_cpu<float>(const float *data);
+template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data);
+template void ZeroCopyTensor::copy_to_cpu<float>(float *data);
+template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
+
 template float *ZeroCopyTensor::data<float>(PaddlePlace *place,
                                             int *size) const;
 template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place,
@@ -92,10 +148,10 @@ void *ZeroCopyTensor::FindTensor() const {
   return tensor;
 }
 
-std::vector<int64_t> ZeroCopyTensor::shape() const {
+std::vector<int> ZeroCopyTensor::shape() const {
   EAGER_GET_TENSOR;
   PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_);
-  return framework::vectorize(tensor->dims());
+  return framework::vectorize2int(tensor->dims());
 }
 
 void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
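A hedged sketch of the new host-to-tensor copies in use, assuming a predictor created as in the earlier snippet with a single float input and output; the sizes are invented:

#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch only; assumes SwitchUseFeedFetchOps(false) was set on the config.
void ZeroCopyRoundTrip(paddle::PaddlePredictor *predictor) {
  auto input = predictor->GetInputTensor(predictor->GetInputNames()[0]);
  input->Reshape({1, 3, 224, 224});
  std::vector<float> in_data(1 * 3 * 224 * 224, 0.f);
  input->copy_from_cpu(in_data.data());  // host -> input tensor (CPU or GPU)

  predictor->ZeroCopyRun();

  auto output = predictor->GetOutputTensor(predictor->GetOutputNames()[0]);
  std::vector<int> shape = output->shape();  // shape() now returns vector<int>
  int numel = 1;
  for (int d : shape) numel *= d;
  std::vector<float> out_data(numel);
  output->copy_to_cpu(out_data.data());  // output tensor -> host
}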
paddle/fluid/inference/api/details/zero_copy_tensor_dummy.cc
@@ -37,7 +37,7 @@ template int64_t *ZeroCopyTensor::mutable_data(PaddlePlace place);
 
 void *ZeroCopyTensor::FindTensor() const { return nullptr; }
 
-std::vector<int64_t> ZeroCopyTensor::shape() const { return {}; }
+std::vector<int> ZeroCopyTensor::shape() const { return {}; }
 
 void ZeroCopyTensor::SetLoD(const std::vector<std::vector<size_t>> &x) {}
paddle/fluid/inference/api/helper.h
@@ -50,6 +50,11 @@ class Timer {
   }
 };
 
+static int GetUniqueId() {
+  static int id = 0;
+  return id++;
+}
+
 static void split(const std::string &str, char sep,
                   std::vector<std::string> *pieces) {
   pieces->clear();
paddle/fluid/inference/api/paddle_analysis_config.h
@@ -135,7 +135,8 @@ struct AnalysisConfig {
    */
   void EnableTensorRtEngine(int workspace_size = 1 << 20,
                             int max_batch_size = 1, int min_subgraph_size = 3,
-                            Precision precision = Precision::kFloat32);
+                            Precision precision = Precision::kFloat32,
+                            bool use_static = true);
   /** A boolean state telling whether the TensorRT engine is used.
    */
   bool tensorrt_engine_enabled() const { return use_tensorrt_; }
@@ -233,6 +234,7 @@ struct AnalysisConfig {
   //  subgraph, 3 as default value.
   int tensorrt_min_subgraph_size_{3};
   Precision tensorrt_precision_mode_;
+  bool trt_use_static_engine_;
   // memory reuse related.
   bool enable_memory_optim_{false};
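With the new trailing parameter, a user opts in to the static engine cache from the public config. A sketch with illustrative sizes; the values are not prescribed by the patch:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch only; workspace/batch sizes are illustrative.
void ConfigureTrt(paddle::AnalysisConfig *config) {
  config->EnableUseGpu(100 /*MB*/, 0 /*gpu id*/);
  // use_static = true serializes the optimized TensorRT engine into the
  // model's optimization cache directory on the first run and reloads it on
  // later runs instead of rebuilding it.
  config->EnableTensorRtEngine(1 << 20 /*workspace_size*/,
                               1 /*max_batch_size*/, 3 /*min_subgraph_size*/,
                               paddle::AnalysisConfig::Precision::kFloat32,
                               true /*use_static*/);
}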
paddle/fluid/inference/api/paddle_api.h
@@ -160,11 +160,21 @@ class ZeroCopyTensor {
   template <typename T>
   T* data(PaddlePlace* place, int* size) const;
 
-  std::vector<int64_t> shape() const;
+  template <typename T>
+  void copy_from_cpu(const T* data);
+
+  template <typename T>
+  void copy_to_cpu(T* data);
+
+  std::vector<int> shape() const;
 
   void SetLoD(const std::vector<std::vector<size_t>>& x);
   std::vector<std::vector<size_t>> lod() const;
   const std::string& name() const { return name_; }
+  void SetPlace(PaddlePlace place, int device = -1) {
+    place_ = place;
+    device_ = device;
+  }
 
  protected:
   explicit ZeroCopyTensor(void* scope) : scope_{scope} {}
@@ -179,6 +189,8 @@ class ZeroCopyTensor {
   // The corresponding tensor pointer inside Paddle workspace is cached for
   // performance.
   mutable void* tensor_{nullptr};
+  PaddlePlace place_;
+  int device_;
 };
 
 /** A simple Inference API for Paddle.
@@ -200,6 +212,14 @@ class PaddlePredictor {
                    std::vector<PaddleTensor>* output_data,
                    int batch_size = -1) = 0;
 
+  /** \brief Get input names of the model
+   */
+  virtual std::vector<std::string> GetInputNames() { return {}; }
+
+  /** \brief Get output names of the model
+   */
+  virtual std::vector<std::string> GetOutputNames() { return {}; }
+
   /** \brief Get a mutable tensor directly.
    *
    * NOTE Only works in AnalysisPredictor.
paddle/fluid/inference/engine.h
@@ -49,11 +49,6 @@ class EngineBase {
   // Execute the engine, that will run the inference network.
   virtual void Execute(int batch_size) = 0;
 
-  // Return the IO buffer that allocated in engine. One can read/write directly
-  // on the buffer. If the buffer's buffer is nullptr, one can also allocate
-  // memory and maintain it outside the engine.
-  virtual Buffer& buffer(const std::string& name) = 0;
-
   virtual ~EngineBase() {}
 };  // class EngineBase
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -18,21 +18,6 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-bool to_skip_merging_optimize(TensorRTEngine* engine,
-                              const std::vector<int>& filters,
-                              const std::vector<int>& strides,
-                              const std::vector<int>& paddings,
-                              std::string input_name) {
-  if (engine->itensor_quote_num[input_name] > 0) {
-    return true;
-  }
-  if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 &&
-      strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0)
-    engine->itensor_quote_num[input_name] += 1;
-
-  return false;
-}
-
 template <typename RegistFunc, typename SetDilationFunc>
 void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
                    const framework::Scope& scope, bool test_mode,
@@ -59,7 +44,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   weight_tensor->Resize(Y_t->dims());
   TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
 
-  auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
+  auto* weight_data = weight_tensor->mutable_data<float>(cpu_place);
 
   PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
   const int n_output = weight_tensor->dims()[0];
@@ -100,9 +85,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   layer->getOutput(0)->setName(output_name.c_str());
   engine->SetITensor(output_name, layer->getOutput(0));
 
-  if (test_mode ||
-      to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings,
-                               op_desc.Input("Input").front())) {
+  if (test_mode) {
     engine->DeclareOutput(output_name);
   }
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
@@ -153,7 +153,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
     if (CheckDims(dims_x, dims_y)) {
       // The two input tensor should have the same dims
       VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
-
       nvinfer1::IElementWiseLayer *layer = TRT_ENGINE_ADD_LAYER(
           engine_, ElementWise, *const_cast<nvinfer1::ITensor *>(X),
           *const_cast<nvinfer1::ITensor *>(Y), op_pair->second);
@@ -166,7 +165,7 @@ class ElementwiseTensorOpConverter : public OpConverter {
                  "ElementWisePluginLayer";
 
       plugin::ElementWisePlugin *plugin =
-          new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis);
+          new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
       plugin->AddInput(X);
       plugin->AddInput(Y);
       nvinfer1::IPluginLayer *layer = engine_->AddPlugin(
paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -85,10 +85,10 @@ class FcOpConverter : public OpConverter {
                 Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
 
     TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                   static_cast<void*>(weight_data),
-                                  Y_t->memory_size() / sizeof(float)};
+                                  static_cast<size_t>(Y_t->numel())};
     TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
                                       static_cast<void*>(tmp->data<float>()),
-                                      Y_t->memory_size() / sizeof(float));
+                                      static_cast<size_t>(Y_t->numel()));
     weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
     tmp_weight.dims = weight.dims;
paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -16,9 +16,12 @@ limitations under the License. */
 
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/utils/singleton.h"
@@ -26,6 +29,37 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+using FluidDT = framework::proto::VarType_Type;
+using TRT_DT = nvinfer1::DataType;
+
+namespace {  // NOLINT
+
+TRT_DT FluidDataType2TRT(FluidDT type) {
+  switch (type) {
+    case FluidDT::VarType_Type_FP32:
+      return TRT_DT::kFLOAT;
+    case FluidDT::VarType_Type_INT32:
+      return TRT_DT::kINT32;
+    default:
+      return TRT_DT::kINT32;
+  }
+  PADDLE_THROW("unkown type");
+  return TRT_DT::kINT32;
+}
+
+nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape) {
+  PADDLE_ENFORCE_GT(shape.size(), 1UL,
+                    "TensorRT' tensor input requires at least 2 dimensions");
+  PADDLE_ENFORCE_LE(shape.size(), 4UL,
+                    "TensorRT' tensor input requires at most 4 dimensions");
+  PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL);
+  if (shape.size() == 4UL)
+    return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
+  return nvinfer1::DimsCHW(shape[1], 1, 1);
+}
+
+}  // namespace // NOLINT
+
 /*
  * Convert Op from Fluid to TensorRT Engine.
  */
@@ -110,6 +144,34 @@ class OpConverter {
     }
   }
 
+  // The scope here should be inited with the parameter vars.
+  void ConvertBlockToTRTEngine(
+      framework::BlockDesc* block_desc, const framework::Scope& scope,
+      const std::vector<std::string>& inputs,
+      const std::unordered_set<std::string>& parameters,
+      const std::vector<std::string>& outputs, TensorRTEngine* engine) {
+    engine->InitNetwork();
+    for (auto& input : inputs) {
+      if (parameters.count(input)) continue;
+      auto* var = block_desc->FindVar(input);
+      PADDLE_ENFORCE(var, "no variable called %s", input);
+      PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
+                        "TensorRT engine only takes LoDTensor as input");
+      auto var_shape = var->GetShape();
+      engine->DeclareInput(
+          input, FluidDataType2TRT(
+                     var->Proto()->type().lod_tensor().tensor().data_type()),
+          Vec2TRT_Dims(var_shape));
+    }
+    framework::proto::BlockDesc* block_proto = block_desc->Proto();
+    ConvertBlock(*block_proto, parameters, scope, engine);
+    for (auto& output : outputs) {
+      engine->DeclareOutput(output);
+    }
+    engine->FreezeNetwork();
+  }
+
   void SetEngine(TensorRTEngine* engine) { engine_ = engine; }
 
   virtual ~OpConverter() {}
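The two anonymous-namespace helpers decide how a Fluid LoDTensor shape becomes a TensorRT dimension: the leading batch dimension is dropped because TensorRT batches implicitly, 4-D shapes map to CHW, and 2-D shapes are padded to CHW with trailing ones. An illustrative sketch with invented shapes:

#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

// Sketch only; the shapes below are illustrative.
void DimsMappingExample() {
  using paddle::inference::tensorrt::Vec2TRT_Dims;
  std::vector<int64_t> nchw = {8, 3, 224, 224};  // N, C, H, W
  nvinfer1::Dims d4 = Vec2TRT_Dims(nchw);        // -> DimsCHW(3, 224, 224)

  std::vector<int64_t> flat = {8, 1000};         // N, features
  nvinfer1::Dims d2 = Vec2TRT_Dims(flat);        // -> DimsCHW(1000, 1, 1)
}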
paddle/fluid/inference/tensorrt/convert/prelu_op.cc
@@ -43,23 +43,20 @@ class PReluOpConverter : public OpConverter {
     PADDLE_ENFORCE_NOT_NULL(alpha_var);
     auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
 
-    platform::CUDAPlace place;
-    std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
+    platform::CPUPlace cpu_place;
+    std::unique_ptr<framework::LoDTensor> alpha_tensor_temp(
         new framework::LoDTensor());
-    alpha_tensor_device->Resize(alpha_tensor->dims());
-    TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
-    float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
+    alpha_tensor_temp->Resize(alpha_tensor->dims());
+    TensorCopySync(*alpha_tensor, cpu_place, alpha_tensor_temp.get());
+    float* alpha_data = alpha_tensor_temp->mutable_data<float>(cpu_place);
 
-    // Transform alpha to TensorRTEngine::Weight
-    TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
-                                    static_cast<void*>(alpha_data),
-                                    alpha_tensor_device->numel());
-    plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode);
+    plugin::PReluPlugin* plugin =
+        new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
     nvinfer1::IPluginLayer* layer =
         engine_->AddPlugin(&input, input_num, plugin);
     // keep alpha tensor to avoid release it's memory
     engine_->weight_map[op_desc.Input("Alpha")[0]] =
-        std::move(alpha_tensor_device);
+        std::move(alpha_tensor_temp);
 
     std::string layer_name = "prelu (Output: ";
     auto output_name = op_desc.Output("Out")[0];
paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -19,7 +19,9 @@ limitations under the License. */
 #pragma once
 
 #include <memory>
 #include <string>
+#include <unordered_set>
+#include <vector>
 
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -79,7 +81,8 @@ class TRTConvertValidation {
         if_add_batch_(if_add_batch),
         max_batch_size_(max_batch_size) {
     PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
-    engine_.reset(new TensorRTEngine(max_batch_size, workspace_size, stream_));
+    engine_.reset(
+        new TensorRTEngine(max_batch_size, workspace_size, false, nullptr, 0));
     engine_->InitNetwork();
   }
@@ -114,13 +117,12 @@ class TRTConvertValidation {
   }
 
   void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
-    platform::CUDAPlace place;
-    platform::CUDADeviceContext ctx(place);
+    platform::CUDADeviceContext ctx(place_);
 
     auto* x = scope_.Var(name);
     auto* x_tensor = x->GetMutable<framework::LoDTensor>();
     x_tensor->Resize(framework::make_ddim(dim_vec));
-    RandomizeTensor(x_tensor, place, ctx);
+    RandomizeTensor(x_tensor, place_, ctx);
   }
   // Declare a variable in a fluid Scope.
   void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
@@ -146,19 +148,6 @@ class TRTConvertValidation {
 
     // Declare outputs.
     op_desc_.reset(new framework::OpDesc(desc, nullptr));
-
-    // Set Inputs.
-    for (const auto& input : op_desc_->InputArgumentNames()) {
-      if (parameters_.count(input)) continue;
-      auto* var = scope_.FindVar(input);
-      PADDLE_ENFORCE(var);
-      auto tensor = var->GetMutable<framework::LoDTensor>();
-
-      engine_->SetInputFromGPU(
-          input, static_cast<void*>(tensor->data<void>()),
-          sizeof(float) *
-              analysis::AccuDims(tensor->dims(), tensor->dims().size()));
-    }
   }
 
   // We use the set 'neglected_output' here, because some Ops like batch norm,
@@ -168,43 +157,71 @@ class TRTConvertValidation {
                std::unordered_set<std::string> neglected_output = {}) {
     // Execute Fluid Op
     PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
-    platform::CUDAPlace place;
-    platform::CUDADeviceContext ctx(place);
-    op_->Run(scope_, place);
-    // Execute TRT.
-    engine_->Execute(batch_size);
-    cudaStreamSynchronize(engine_->stream());
+    platform::CUDADeviceContext ctx(place_);
+    op_->Run(scope_, place_);
 
-    ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
-    const size_t output_space_size = 3000;
+    std::vector<std::string> input_output_names;
+
+    // Note: we need filter the parameter
+    for (const auto& input : op_desc_->InputArgumentNames()) {
+      if (parameters_.count(input)) continue;
+      input_output_names.push_back(input);
+    }
+
+    // Collect the fluid outputs.
+    std::vector<std::vector<float>> fluid_outs;
     for (const auto& output : op_desc_->OutputArgumentNames()) {
       if (neglected_output.count(output)) continue;
+      input_output_names.push_back(output);
       std::vector<float> fluid_out;
-      std::vector<float> trt_out(output_space_size);
-      engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
-      cudaStreamSynchronize(engine_->stream());
-
       auto* var = scope_.FindVar(output);
-      auto tensor = var->GetMutable<framework::LoDTensor>();
+      auto* tensor = var->GetMutable<framework::LoDTensor>();
       framework::TensorToVector(*tensor, ctx, &fluid_out);
+      fluid_outs.push_back(fluid_out);
+    }
+
+    // Bind input and output for TRT.
+    const int num_bindings = input_output_names.size();
+    std::vector<void*> buffers(num_bindings);
+
+    for (const std::string& name : input_output_names) {
+      auto* var = scope_.FindVar(name);
+      auto* tensor = var->GetMutable<framework::LoDTensor>();
+      const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
+      buffers[bind_index] =
+          static_cast<void*>(tensor->mutable_data<float>(place_));
+    }
+
+    // Execute TRT.
+    engine_->Execute(batch_size, &buffers, stream_);
 
-      size_t fluid_out_size = fluid_out.size();
+    ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
+    int index = 0;
+    for (const auto& output : op_desc_->OutputArgumentNames()) {
+      if (neglected_output.count(output)) continue;
+      std::vector<float> trt_out;
+      auto* var = scope_.FindVar(output);
+      auto* tensor = var->GetMutable<framework::LoDTensor>();
+      framework::TensorToVector(*tensor, ctx, &trt_out);
+
+      size_t fluid_out_size = fluid_outs[index].size();
       if (if_add_batch_ == true) {
         fluid_out_size =
             batch_size * (framework::product(tensor->dims()) / max_batch_size_);
       }
-      // Compare two output
-      ASSERT_FALSE(fluid_out.empty());
+
       for (size_t i = 0; i < fluid_out_size; i++) {
-        EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5);
+        // Loose the threshold for CI in different machine model.
+        EXPECT_LT(std::abs(fluid_outs[index][i] - trt_out[i]), 2e-5);
       }
+      index += 1;
     }
   }
 
   framework::Scope& scope() { return scope_; }
 
  private:
+  platform::CUDAPlace place_;
   std::unique_ptr<TensorRTEngine> engine_;
   cudaStream_t stream_;
   std::unique_ptr<framework::OperatorBase> op_;
paddle/fluid/inference/tensorrt/engine.cc
@@ -32,36 +32,18 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
   PADDLE_ENFORCE(false, "not implemented");
 }
 
-void TensorRTEngine::Execute(int batch_size) {
+void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
+                             cudaStream_t stream) {
   freshDeviceId();
   batch_size_ = batch_size;
-  std::vector<void *> buffers;
-  for (auto &buf : buffers_) {
-    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
-    PADDLE_ENFORCE_GT(buf.max_size, 0);
-    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
-    buffers.push_back(buf.buffer);
-  }
-  infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr);
-  cudaStreamSynchronize(stream_);
+  infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
+  cudaStreamSynchronize(stream);
   SetRuntimeBatch(batch_size);
 }
 
-TensorRTEngine::~TensorRTEngine() {
-  cudaStreamSynchronize(stream_);
-  // clean buffer
-  for (auto &buf : buffers_) {
-    if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
-      PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
-      buf.buffer = nullptr;
-      buf.max_size = 0;
-    }
-  }
-}
-
 void TensorRTEngine::FreezeNetwork() {
-  VLOG(3) << "TRT to freeze network";
   freshDeviceId();
+  VLOG(3) << "TRT to freeze network";
   PADDLE_ENFORCE(infer_builder_ != nullptr,
                  "Call InitNetwork first to initialize network.");
   PADDLE_ENFORCE(infer_network_ != nullptr,
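After this change the engine no longer owns device buffers; the caller binds a device pointer per TensorRT binding and passes the vector plus a CUDA stream on every call. A rough sketch of the new calling convention; the binding names, sizes, and helper function are invented and error checking is omitted:

#include <vector>
#include <cuda_runtime.h>
#include "paddle/fluid/inference/tensorrt/engine.h"

// Sketch only; "x"/"y" and the 1 KB sizes are made-up bindings.
void RunOnce(paddle::inference::tensorrt::TensorRTEngine *trt_engine,
             int batch_size) {
  auto *cuda_engine = trt_engine->engine();
  std::vector<void *> buffers(cuda_engine->getNbBindings(), nullptr);
  cudaMalloc(&buffers[cuda_engine->getBindingIndex("x")], 1024);  // input
  cudaMalloc(&buffers[cuda_engine->getBindingIndex("y")], 1024);  // output

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  trt_engine->Execute(batch_size, &buffers, stream);  // new signature
  cudaStreamDestroy(stream);
}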
@@ -81,30 +63,6 @@ void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
  infer_context_.reset(infer_engine_->createExecutionContext());
  // allocate GPU buffers.
  buffers_.resize(buffer_sizes_.size());
  for (auto &item : buffer_sizes_) {
    // The output buffers are not set in the network building phrase, need to
    // infer from the TesorRT network.
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
      PADDLE_ENFORCE_GT(item.second, 0);
    }
    auto &buf = buffer(item.first);
    buf.max_size = item.second * max_batch_;
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
    buf.size = 0;
    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
    buf.device = DeviceType::GPU;
  }
}

nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
...
...
@@ -158,83 +116,6 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
  buffer_sizes_[name] = 0;
}

void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
  return buffer(name).buffer;
}

void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
                                    size_t max_size) {
  // determine data size
  auto *output = TensorRTEngine::GetITensor(name);
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_LE(dst_size, it->second);
  PADDLE_ENFORCE_GE(max_size, dst_size);
  auto &buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                    cudaMemcpyDeviceToDevice, stream_),
                    0);
}

void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
                                    size_t max_size) {
  // determine data size
  auto *output = TensorRTEngine::GetITensor(name);
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_LE(dst_size, it->second);
  PADDLE_ENFORCE_GE(max_size, dst_size);
  auto &buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                       cudaMemcpyDeviceToHost, stream_));
}

Buffer &TensorRTEngine::buffer(const std::string &name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
                 name);
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  return buffers_[slot_offset];
}

void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
                                     size_t size) {
  auto &buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_NOT_NULL(data);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  buf.size = size;
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, stream_));
}

void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
                                     size_t size) {
  auto &buf = buffer(name);
  buf.size = size;
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, stream_));
}

void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
  PADDLE_ENFORCE(tensor != nullptr);
...
...
@@ -254,13 +135,6 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }

void TensorRTEngine::freshDeviceId() {
  int count;
  cudaGetDeviceCount(&count);
  PADDLE_ENFORCE_LT(device_, count);
  cudaSetDevice(device_);
}

nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
    nvinfer1::ITensor *const *inputs, int num_inputs,
    plugin::PluginTensorRT *plugin) {
...
...
@@ -268,6 +142,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
  return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin);
}

void TensorRTEngine::freshDeviceId() {
  int count;
  cudaGetDeviceCount(&count);
  PADDLE_ENFORCE_LT(device_id_, count);
  cudaSetDevice(device_id_);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/engine.h
...
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
...
...
@@ -37,7 +38,9 @@ class TRTInt8Calibrator;
* There are two alternative ways to use it, one is to build from a paddle
* protobuf model, another way is to manully construct the network.
*/
class TensorRTEngine : public EngineBase {
class TensorRTEngine {
  using DescType = ::paddle::framework::proto::BlockDesc;

 public:
  // Weight is model parameter.
  class Weight {
...
...
@@ -56,28 +59,28 @@ class TensorRTEngine : public EngineBase {
    nvinfer1::Weights w_;
  };

  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream,
                 int device = 0, bool enable_int8 = false,
                 TRTInt8Calibrator *calibrator = nullptr,
  TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false,
                 TRTInt8Calibrator *calibrator = nullptr, int device_id = 0,
                 nvinfer1::ILogger &logger = NaiveLogger::Global())
      : max_batch_(max_batch),
        max_workspace_(max_workspace),
        stream_(stream),
        device_(device),
        enable_int8_(enable_int8),
        calibrator_(calibrator),
        device_id_(device_id),
        logger_(logger) {}

  virtual ~TensorRTEngine();
  ~TensorRTEngine() {}

  // TODO(Superjomn) implement it later when graph segmentation is supported.
  void Build(const DescType &paddle_model) override;
  void Build(const DescType &paddle_model);

  void Execute(int batch_size) override;
  void Execute(int batch_size, std::vector<void *> *buffers,
               cudaStream_t stream);

  // Initialize the inference network, so that TensorRT layers can add to this
  // network.
  void InitNetwork() {
    freshDeviceId();
    infer_builder_.reset(createInferBuilder(&logger_));
    infer_network_.reset(infer_builder_->createNetwork());
  }
...
...
@@ -98,37 +101,34 @@ class TensorRTEngine : public EngineBase {
  // Check if the ITensor has been declared
  bool HasDeclared(const std::string &name);

  // GPU memory address for an ITensor with specific name. One can operate on
  // these memory directly for acceleration, for example, output the converted
  // data directly to the buffer to save data copy overhead.
  // NOTE this should be used after calling `FreezeNetwork`.
  Buffer &buffer(const std::string &name) override;

  cudaStream_t stream() { return stream_; }

  // Fill an input from CPU memory with name and size.
  void SetInputFromCPU(const std::string &name, const void *data, size_t size);
  // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
  // accessed directly. Fill an input from GPU memory with name and size.
  void SetInputFromGPU(const std::string &name, const void *data, size_t size);
  // Get an output called name, the output of tensorrt is in GPU, so this method
  // Return the output's GPU memory address without copy.
  void *GetOutputInGPU(const std::string &name);
  // Copy data into dst inside the GPU device.
  void GetOutputInGPU(const std::string &name, void *dst, size_t max_size);
  // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
  // to CPU.
  void GetOutputInCPU(const std::string &name, void *dst, size_t max_size);
  // Fill an ITensor into map itensor_map_.
  void SetITensor(const std::string &name, nvinfer1::ITensor *tensor);
  // Get an ITensor called name.
  nvinfer1::ITensor *GetITensor(const std::string &name);

  nvinfer1::ICudaEngine *engine() { return infer_engine_.get(); }
  nvinfer1::INetworkDefinition *network() { return infer_network_.get(); }

  nvinfer1::IHostMemory *Serialize() {
    PADDLE_ENFORCE(infer_engine_ != nullptr,
                   "You should build engine first and then serialize");
    ihost_memory_.reset(infer_engine_->serialize());
    return ihost_memory_.get();
  }

  void Deserialize(const std::string &engine_serialized_data) {
    freshDeviceId();
    infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
    infer_engine_.reset(runtime->deserializeCudaEngine(
        engine_serialized_data.c_str(), engine_serialized_data.size(),
        &inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
    PADDLE_ENFORCE(infer_engine_ != nullptr,
                   "build cuda engine failed when deserialize engine info.!");
    infer_context_.reset(infer_engine_->createExecutionContext());
  }

  void SetRuntimeBatch(size_t batch_size);
  int GetRuntimeBatch();
  int GetDevice() { return device_; }
  int GetDeviceId() { return device_id_; }
  nvinfer1::IPluginLayer *AddPlugin(nvinfer1::ITensor *const *inputs,
                                    int num_inputs, plugin::PluginTensorRT *);
...
...
@@ -140,17 +140,12 @@ class TensorRTEngine : public EngineBase {
  std::unordered_map<std::string /*name*/, std::unique_ptr<framework::Tensor>>
      weight_map;

  // TODO(NHZLX)
  // In the normal case, the paddle-trt exists bug when runing the googlenet.
  // When there are more than two convolutions of 1 * 1 with the same input, the
  // paddle-tensorrt will do the merging optimization, which fuse those conv
  // into one conv, and then trigger bug. So, We should use strategy to avoid
  // this
  // optimization for the time being. This bug will be fixed in the future.
  std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
      itensor_quote_num;

 private:
  // Each ICudaEngine object is bound to a specific GPU when it is instantiated,
  // ensure that the thread is associated with the correct device by calling
  // freshDeviceId().
  void freshDeviceId();

  // the max batch size
  int max_batch_;
  // the runtime batch size
...
...
@@ -158,18 +153,14 @@ class TensorRTEngine : public EngineBase {
  // the max memory size the engine uses
  int max_workspace_;

  cudaStream_t stream_;
  // The specific GPU id that the TensorRTEngine bounded to.
  int device_;
  bool enable_int8_;
  TRTInt8Calibrator *calibrator_;
  // batch size of the current data, will be updated each Executation.
  int batch_size_{-1};

  int device_id_;
  nvinfer1::ILogger &logger_;

  std::vector<Buffer> buffers_;
  // max data size for the buffers.
  std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
  std::unordered_map<std::string /*name*/, nvinfer1::ITensor * /*ITensor*/>
...
...
@@ -192,15 +183,11 @@ class TensorRTEngine : public EngineBase {
  infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
  infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
  infer_ptr<nvinfer1::IExecutionContext> infer_context_;
  // Each ICudaEngine object is bound to a specific GPU when it is instantiated,
  // ensure that the thread is associated with the correct device by calling
  // freshDeviceId().
  void freshDeviceId();
  infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
};  // class TensorRTEngine

// Add an layer__ into engine__ with args ARGS.
// For example:
//   TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
//
// Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
...
...
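Serialize() and Deserialize() above are what allow the subgraph pass to cache a built ICudaEngine instead of rebuilding it on every load. A minimal round-trip sketch, assuming an engine that has already been frozen (the variable names here are illustrative, not part of the patch):

// Save the frozen engine as a string blob and restore it later (sketch).
nvinfer1::IHostMemory *blob = engine->Serialize();
std::string engine_serialized_data(static_cast<const char *>(blob->data()),
                                   blob->size());
// ... persist the blob, e.g. as the op's "engine_serialized_data" attribute ...
TensorRTEngine restored(max_batch, max_workspace, /*enable_int8=*/false,
                        /*calibrator=*/nullptr, /*device_id=*/0);
restored.Deserialize(engine_serialized_data);  // rebuilds engine + execution context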
paddle/fluid/inference/tensorrt/helper.h
...
...
@@ -17,6 +17,9 @@
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
...
...
@@ -74,6 +77,32 @@ class NaiveLogger : public nvinfer1::ILogger {
  ~NaiveLogger() override {}
};

class NaiveProfiler : public nvinfer1::IProfiler {
 public:
  typedef std::pair<std::string, float> Record;
  std::vector<Record> mProfile;

  virtual void reportLayerTime(const char *layerName, float ms) {
    auto record =
        std::find_if(mProfile.begin(), mProfile.end(),
                     [&](const Record &r) { return r.first == layerName; });
    if (record == mProfile.end())
      mProfile.push_back(std::make_pair(layerName, ms));
    else
      record->second += ms;
  }

  void printLayerTimes() {
    float totalTime = 0;
    for (size_t i = 0; i < mProfile.size(); i++) {
      printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(),
             mProfile[i].second);
      totalTime += mProfile[i].second;
    }
    printf("Time over all layers: %4.3f\n", totalTime);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
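NaiveProfiler implements the standard nvinfer1::IProfiler callback, so it can be attached to an execution context to collect per-layer timings. A usage sketch under assumed variables (engine, buffers and batch_size are illustrative, not part of the patch):

// Sketch: attach the profiler and dump per-layer times after a synchronous run.
NaiveProfiler profiler;
nvinfer1::IExecutionContext *context = engine->engine()->createExecutionContext();
context->setProfiler(&profiler);               // TensorRT invokes reportLayerTime per layer
context->execute(batch_size, buffers.data());  // synchronous execute so timings are recorded
profiler.printLayerTimes();
context->destroy();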
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
nv_library(tensorrt_plugin
  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
  prelu_op_plugin.cu trt_plugin_factory.cc
  avg_pool_op_plugin.cu
  DEPS enforce tensorrt_engine prelu)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
...
...
@@ -20,6 +21,12 @@ namespace inference {
namespace tensorrt {
namespace plugin {

AvgPoolPlugin *CreateAvgPoolPluginDeserialize(const void *buffer,
                                              size_t length) {
  return new AvgPoolPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("avg_pool_plugin", CreateAvgPoolPluginDeserialize);

nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
  assert(nbInputs == 1);
...
...
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
...
...
@@ -33,24 +33,27 @@ class AvgPoolPlugin : public PluginTensorRT {
 protected:
  size_t getSerializationSize() override {
    return SerializedSize(ceil_mode_) + SerializedSize(ksize_) +
           SerializedSize(strides_) + SerializedSize(paddings_) +
           SerializedSize(input_shape_) + getBaseSerializationSize();
    return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) +
           SerializedSize(ksize_) + SerializedSize(strides_) +
           SerializedSize(paddings_) + SerializedSize(input_shape_) +
           SerializedSize(output_shape_) + getBaseSerializationSize();
  }

  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, ceil_mode_);
    SerializeValue(&buffer, ksize_);
    SerializeValue(&buffer, strides_);
    SerializeValue(&buffer, paddings_);
    SerializeValue(&buffer, input_shape_);
    SerializeValue(&buffer, output_shape_);
  }

 public:
  AvgPoolPlugin() {}
  AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
                std::vector<int> strides, std::vector<int> paddings,
                std::vector<int> input_shape)
...
...
@@ -89,6 +92,7 @@ class AvgPoolPlugin : public PluginTensorRT {
    DeserializeValue(&serialData, &serialLength, &strides_);
    DeserializeValue(&serialData, &serialLength, &paddings_);
    DeserializeValue(&serialData, &serialLength, &input_shape_);
    DeserializeValue(&serialData, &serialLength, &output_shape_);
  }

  AvgPoolPlugin *clone() const override {
...
...
@@ -96,7 +100,7 @@ class AvgPoolPlugin : public PluginTensorRT {
                             input_shape_);
  }

  const char *getPluginType() const override { return "avg_pool"; }
  const char *getPluginType() const override { return "avg_pool_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
...
...
paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
...
...
@@ -14,12 +14,19 @@ limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

ElementWisePlugin *CreateElementWisePluginDeserialize(const void *buffer,
                                                      size_t length) {
  return new ElementWisePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);

namespace details {

template <typename T>
...
...
@@ -119,10 +126,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs,
  const float *y = reinterpret_cast<const float *>(inputs[1]);
  float *out = reinterpret_cast<float *>(outputs[0]);

  if (type_ == nvinfer1::ElementWiseOperation::kSUM) {
  if (type_ == "add") {
    details::ElementWise(details::Add<float>(), x, y, out, batch_size,
                         prev_size_, midd_size_, post_size_, stream);
  } else if (type_ == nvinfer1::ElementWiseOperation::kPROD) {
  } else if (type_ == "mul") {
    details::ElementWise(details::Mul<float>(), x, y, out, batch_size,
                         prev_size_, midd_size_, post_size_, stream);
  } else {
...
...
paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...
...
@@ -24,9 +25,8 @@ namespace plugin {
class ElementWisePlugin : public PluginTensorRT {
 public:
  ElementWisePlugin(nvinfer1::ElementWiseOperation type,
                    nvinfer1::Dims const &dims_x, nvinfer1::Dims const &dims_y,
                    int axis)
  ElementWisePlugin(std::string type, nvinfer1::Dims const &dims_x,
                    nvinfer1::Dims const &dims_y, int axis)
      : type_(type),
        dims_x_(dims_x),
        dims_y_(dims_y),
...
...
@@ -37,6 +37,9 @@ class ElementWisePlugin : public PluginTensorRT {
  ElementWisePlugin(void const *serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);
    const char *elementwise_type;
    DeserializeValue(&serial_data, &serial_length, &elementwise_type);
    type_ = std::string(elementwise_type);
    DeserializeValue(&serial_data, &serial_length, &axis_);
    DeserializeValue(&serial_data, &serial_length, &dims_x_);
    DeserializeValue(&serial_data, &serial_length, &dims_y_);
...
...
@@ -47,7 +50,7 @@ class ElementWisePlugin : public PluginTensorRT {
    return nullptr;
  }

  const char *getPluginType() const override { return "elementwise"; }
  const char *getPluginType() const override { return "elementwise_plugin"; }

  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims *input_dims,
...
...
@@ -61,18 +64,21 @@ class ElementWisePlugin : public PluginTensorRT {
protected:
size_t
getSerializationSize
()
override
{
return
SerializedSize
(
axis_
)
+
SerializedSize
(
dims_x_
)
+
SerializedSize
(
dims_y_
)
+
getBaseSerializationSize
();
return
SerializedSize
(
getPluginType
())
+
SerializedSize
(
axis_
)
+
SerializedSize
(
dims_x_
)
+
SerializedSize
(
dims_y_
)
+
getBaseSerializationSize
();
}
void
serialize
(
void
*
buffer
)
override
{
SerializeValue
(
&
buffer
,
getPluginType
());
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
type_
.
c_str
());
SerializeValue
(
&
buffer
,
axis_
);
SerializeValue
(
&
buffer
,
dims_x_
);
SerializeValue
(
&
buffer
,
dims_y_
);
}
nvinfer1
::
ElementWiseOperation
type_
;
std
::
string
type_
;
nvinfer1
::
Dims
dims_x_
;
nvinfer1
::
Dims
dims_y_
;
int
axis_
;
...
...
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
...
...
@@ -17,6 +17,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/prelu.h"
namespace paddle {
...
...
@@ -24,6 +25,17 @@ namespace inference {
namespace tensorrt {
namespace plugin {

PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
  return new PReluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);

int PReluPlugin::initialize() {
  cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
  cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
             cudaMemcpyHostToDevice);
}

nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
...
...
@@ -39,7 +51,8 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
  // input dims is CHW.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
  // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
  const float *alpha = p_gpu_weight_;
  float *output = reinterpret_cast<float **>(outputs)[0];

  std::vector<int> input_shape;
...
...
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
...
...
@@ -14,7 +14,12 @@
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...
...
@@ -24,39 +29,51 @@ namespace tensorrt {
namespace plugin {

class PReluPlugin : public PluginTensorRT {
  TensorRTEngine::Weight alpha_;
  std::vector<float> weight_;
  float *p_gpu_weight_;
  std::string mode_;

 protected:
  size_t getSerializationSize() override {
    // return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
    return 0;
    return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
           SerializedSize(weight_) + SerializedSize(getPluginType());
  }

  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
  void serialize(void *buffer) override {
    // serializeBase(buffer);
    // SerializeValue(&buffer, alpha_);
    // SerializeValue(&buffer, mode_);
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, weight_);
    SerializeValue(&buffer, mode_.c_str());
  }

 public:
  PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
      : alpha_(alpha), mode_(mode) {}
  PReluPlugin(const float *weight, const int weight_num,
              std::string const &mode)
      : mode_(mode) {
    weight_.resize(weight_num);
    std::copy(weight, weight + weight_num, weight_.data());
  }

  // It was used for tensorrt deserialization.
  // It should not be called by users.
  PReluPlugin(void const *serialData, size_t serialLength) {
    // deserializeBase(serialData, serialLength);
    // DeserializeValue(&serialData, &serialLength, &alpha_);
    // DeserializeValue(&serialData, &serialLength, &mode_);
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &weight_);
    const char *prelu_mode;
    DeserializeValue(&serialData, &serialLength, &prelu_mode);
    mode_ = std::string(prelu_mode);
  }

  ~PReluPlugin() { cudaFree(p_gpu_weight_); }
  int initialize() override;

  PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
  PReluPlugin *clone() const override {
    return new PReluPlugin(weight_.data(), weight_.size(), mode_);
  }

  const char *getPluginType() const override { return "prelu"; }
  const char *getPluginType() const override { return "prelu_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
...
...
@@ -15,12 +15,18 @@
#include <cuda_fp16.h>
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

SplitPlugin *CreateSplitPluginDeserialize(const void *buffer, size_t length) {
  return new SplitPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);

// copied from operators::math::SplitFunctor
template <typename T>
__global__ void SplitKernel(const T *input_data, const int in_row,
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <thrust/device_vector.h>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...
...
@@ -25,6 +26,7 @@ namespace plugin {
class SplitPlugin : public PluginTensorRT {
 public:
  SplitPlugin() {}
  SplitPlugin(int axis, std::vector<int> const &output_lengths)
      : axis_(axis), same_shape_(true), output_length_(output_lengths) {}
...
...
@@ -38,7 +40,7 @@ class SplitPlugin : public PluginTensorRT {
    return new SplitPlugin(axis_, output_length_);
  }

  const char *getPluginType() const override { return "split"; }
  const char *getPluginType() const override { return "split_plugin"; }
  int getNbOutputs() const override { return output_length_.size(); }
  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims *input_dims,
...
...
@@ -50,11 +52,12 @@ class SplitPlugin : public PluginTensorRT {
 protected:
  size_t getSerializationSize() override {
    return SerializedSize(axis_) + SerializedSize(output_length_) +
           getBaseSerializationSize();
    return SerializedSize(getPluginType()) + SerializedSize(axis_) +
           SerializedSize(output_length_) + getBaseSerializationSize();
  }

  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, axis_);
    SerializeValue(&buffer, output_length_);
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
...
...
@@ -17,9 +17,10 @@
#include <NvInfer.h>
#include <cstring>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/
serialize
.h"
#include "paddle/fluid/inference/tensorrt/plugin/
trt_plugin_utils
.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -30,6 +31,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {

class PluginTensorRT;

typedef std::function<PluginTensorRT *(const void *, size_t)>
    PluginDeserializeFunc;
typedef std::function<PluginTensorRT *(void)> PluginConstructFunc;

class PluginTensorRT : public nvinfer1::IPluginExt {
 public:
  PluginTensorRT() {}
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

PluginTensorRT *PluginFactoryTensorRT::createPlugin(const char *layer_name,
                                                    const void *serial_data,
                                                    size_t serial_length) {
  const char *plugin_type;
  DeserializeValue(&serial_data, &serial_length, &plugin_type);

  PADDLE_ENFORCE(Has(plugin_type),
                 "trt plugin type %s does not exists, check it.", plugin_type);
  auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
  owned_plugins_.emplace_back(plugin);

  return plugin;
}

bool PluginFactoryTensorRT::RegisterPlugin(
    const std::string &op_name, PluginDeserializeFunc deserialize_func) {
  if (Has(op_name)) return false;
  auto ret = plugin_registry_.emplace(op_name, deserialize_func);
  return ret.second;
}

void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <cstring>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class PluginFactoryTensorRT : public nvinfer1::IPluginFactory,
                              public DeleteHelper {
 public:
  // Deserialization method
  PluginTensorRT *createPlugin(const char *layer_name, const void *serial_data,
                               size_t serial_length) override;

  bool RegisterPlugin(const std::string &op_name,
                      PluginDeserializeFunc deserialize_func);

  bool Has(const std::string &op_name) {
    return plugin_registry_.find(op_name) != plugin_registry_.end();
  }

  void DestroyPlugins();

 protected:
  std::unordered_map<std::string, PluginDeserializeFunc> plugin_registry_;

  std::list<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};

class TrtPluginRegistrar {
 public:
  TrtPluginRegistrar(const std::string &name,
                     PluginDeserializeFunc deserialize_func) {
    inference::Singleton<PluginFactoryTensorRT>::Global().RegisterPlugin(
        name, deserialize_func);
  }
};

#define REGISTER_TRT_PLUGIN(name, deserialize_func) \
  REGISTER_TRT_PLUGIN_UNIQ(__COUNTER__, name, deserialize_func)

#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func)      \
  static paddle::inference::tensorrt::plugin::TrtPluginRegistrar   \
      trt_plugin_registrar##ctr __attribute__((unused)) =          \
          paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
              name, deserialize_func)

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
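A plugin participates in deserialization by pairing a creation function with REGISTER_TRT_PLUGIN, keyed by the type string that serialize() writes first; PluginFactoryTensorRT::createPlugin then dispatches on that string. A sketch for a hypothetical plugin (MyPlugin and "my_plugin" are illustrative names only, not part of this patch):

// Hypothetical registration sketch; MyPlugin is assumed to derive from
// PluginTensorRT and to write getPluginType() first in its serialize().
MyPlugin *CreateMyPluginDeserialize(const void *buffer, size_t length) {
  return new MyPlugin(buffer, length);  // rebuild the plugin from the serialized blob
}
REGISTER_TRT_PLUGIN("my_plugin", CreateMyPluginDeserialize);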
paddle/fluid/inference/tensorrt/plugin/serialize.h → paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
#pragma once
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
...
...
@@ -24,6 +24,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {

// Some trt base classes lack of the destructor.
// We use a assisted class to fix this.
struct DeleteHelper {
 protected:
  virtual ~DeleteHelper() {}
};

template <typename T>
inline void SerializeValue(void **buffer, T const &value);
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
...
...
@@ -17,6 +17,8 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/enforce.h"
...
...
@@ -27,19 +29,34 @@ namespace tensorrt {
class TensorRTEngineTest : public ::testing::Test {
 protected:
  void SetUp() override {
    ASSERT_EQ(0, cudaStreamCreate(&stream_));
    engine_ = new TensorRTEngine(10, 1 << 10, stream_);
    ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0));

    engine_ = new TensorRTEngine(10, 1 << 10);
    engine_->InitNetwork();
  }

  void TearDown() override {
    delete engine_;
    cudaStreamDestroy(stream_);
    if (engine_) {
      delete engine_;
      engine_ = nullptr;
    }
  }

  void PrepareInputOutput(const std::vector<float> &input,
                          std::vector<int> output_shape) {
    TensorFromVector(input, *ctx_, &input_);
    output_.Resize(framework::make_ddim(output_shape));
  }

  void GetOutput(std::vector<float> *output) {
    TensorToVector(output_, *ctx_, output);
  }

 protected:
  TensorRTEngine *engine_;
  cudaStream_t stream_;
  framework::Tensor input_;
  framework::Tensor output_;
  TensorRTEngine *engine_;
  platform::CUDADeviceContext *ctx_;
};

TEST_F(TensorRTEngineTest, add_layer) {
...
...
@@ -48,12 +65,14 @@ TEST_F(TensorRTEngineTest, add_layer) {
  float raw_weight[size] = {2.};  // Weight in CPU memory.
  float raw_bias[size] = {3.};
  std::vector<void *> buffers(2);  // TRT binded inputs

  LOG(INFO) << "create weights";
  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::DimsCHW{1, 1, 1});
  auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
  auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
                                        weight.get(), bias.get());
  PADDLE_ENFORCE(fc_layer != nullptr);
...
...
@@ -63,18 +82,24 @@ TEST_F(TensorRTEngineTest, add_layer) {
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  // fill in real data
  float x_v = 1234;
  engine_->SetInputFromCPU("x", reinterpret_cast<void *>(&x_v),
                           1 * sizeof(float));
  std::vector<float> x_v = {1234};
  std::vector<float> y_cpu;
  PrepareInputOutput(x_v, {1});

  auto *x_v_gpu_data = input_.mutable_data<float>(ctx_->GetPlace());
  auto *y_gpu_data = output_.mutable_data<float>(ctx_->GetPlace());

  buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
  buffers[1] = reinterpret_cast<void *>(y_gpu_data);

  LOG(INFO) << "to execute";
  engine_->Execute(1);
  engine_->Execute(1, &buffers, ctx_->stream());

  LOG(INFO) << "to get output";
  float y_cpu;
  engine_->GetOutputInCPU("y", &y_cpu, 1 * sizeof(float));
  GetOutput(&y_cpu);

  LOG(INFO) << "to checkout output";
  ASSERT_EQ(y_cpu, x_v * 2 + 3);
  ASSERT_EQ(y_cpu[0], x_v[0] * 2 + 3);
}

TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
...
...
@@ -83,12 +108,13 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
  // instead of row-major, which is [[1.0, 1.1], [3.3, 4.4]]
  float raw_weight[4] = {1.0, 1.1, 3.3, 4.4};
  float raw_bias[2] = {1.3, 2.4};
  std::vector<void *> buffers(2);  // TRT binded inputs

  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 4);
  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 2);
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::DimsCHW{1, 2, 1});
  auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2,
  auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2,
                                        weight.get(), bias.get());
  PADDLE_ENFORCE(fc_layer != nullptr);
...
...
@@ -96,19 +122,27 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
  engine_->FreezeNetwork();
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  float x_v[2] = {1.0, 2.0};
  engine_->SetInputFromCPU("x", reinterpret_cast<void *>(&x_v),
                           2 * sizeof(float));
  engine_->Execute(1);
  // fill in real data
  std::vector<float> x_v = {1.0, 2.0};
  std::vector<float> y_cpu;
  PrepareInputOutput(x_v, {2});

  auto *x_v_gpu_data = input_.mutable_data<float>(ctx_->GetPlace());
  auto *y_gpu_data = output_.mutable_data<float>(ctx_->GetPlace());

  buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
  buffers[1] = reinterpret_cast<void *>(y_gpu_data);

  engine_->Execute(1, &buffers, ctx_->stream());

  LOG(INFO) << "to get output";
  float y_cpu[2] = {-1., -1.};
  GetOutput(&y_cpu);
  auto dims = engine_->GetITensor("y")->getDimensions();
  ASSERT_EQ(dims.nbDims, 3);
  ASSERT_EQ(dims.d[0], 2);
  ASSERT_EQ(dims.d[1], 1);
  engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
  ASSERT_EQ(y_cpu[0], 4.5);
  ASSERT_EQ(y_cpu[1], 14.5);
}
...
...
@@ -117,12 +151,13 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
  // Weight in CPU memory.
  float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
  float raw_bias[1] = {0};
  std::vector<void *> buffers(2);  // TRT binded inputs

  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 9);
  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 1);
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::Dims3{1, 3, 3});
  auto *conv_layer =
  auto *conv_layer =
      TRT_ENGINE_ADD_LAYER(engine_, Convolution, *x, 1, nvinfer1::DimsHW{3, 3},
                           weight.get(), bias.get());
  PADDLE_ENFORCE(conv_layer != nullptr);
...
...
@@ -133,28 +168,36 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
  engine_->FreezeNetwork();
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  float x_v[18] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                   1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
  engine_->SetInputFromCPU("x", reinterpret_cast<void *>(&x_v),
                           18 * sizeof(float));
  engine_->Execute(2);
  // fill in real data
  std::vector<float> x_v = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
  std::vector<float> y_cpu;
  PrepareInputOutput(x_v, {18});

  auto *x_v_gpu_data = input_.mutable_data<float>(ctx_->GetPlace());
  auto *y_gpu_data = output_.mutable_data<float>(ctx_->GetPlace());

  buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
  buffers[1] = reinterpret_cast<void *>(y_gpu_data);

  engine_->Execute(2, &buffers, ctx_->stream());

  LOG(INFO) << "to get output";
  float *y_cpu = new float[18];
  engine_->GetOutputInCPU("y", &y_cpu[0], 18 * sizeof(float));
  GetOutput(&y_cpu);

  ASSERT_EQ(y_cpu[0], 4.0);
  ASSERT_EQ(y_cpu[1], 6.0);
}

TEST_F(TensorRTEngineTest, test_pool2d) {
  // Weight in CPU memory.
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
  auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::Dims3{1, 2, 2});
  std::vector<void *> buffers(2);  // TRT binded inputs

  nvinfer1::PoolingType pool_t = nvinfer1::PoolingType::kAVERAGE;
  auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
                                          *const_cast<nvinfer1::ITensor *>(x),
                                          pool_t, nvinfer1::DimsHW{2, 2});
  auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *x, pool_t,
                                          nvinfer1::DimsHW{2, 2});

  PADDLE_ENFORCE(pool_layer != nullptr);
  pool_layer->setStride(nvinfer1::DimsHW{1, 1});
...
...
@@ -164,14 +207,21 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
  engine_->FreezeNetwork();
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  float x_v[8] = {1.0, 2.0, 5.0, 0.0, 2.0, 3.0, 5.0, 10.0};
  engine_->SetInputFromCPU("x", reinterpret_cast<void *>(&x_v),
                           8 * sizeof(float));
  engine_->Execute(2);
  // fill in real data
  std::vector<float> x_v = {1.0, 2.0, 5.0, 0.0, 2.0, 3.0, 5.0, 10.0};
  std::vector<float> y_cpu;
  PrepareInputOutput(x_v, {2});

  auto *x_v_gpu_data = input_.mutable_data<float>(ctx_->GetPlace());
  auto *y_gpu_data = output_.mutable_data<float>(ctx_->GetPlace());

  buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
  buffers[1] = reinterpret_cast<void *>(y_gpu_data);

  engine_->Execute(2, &buffers, ctx_->stream());

  LOG(INFO) << "to get output";
  float *y_cpu = new float[2];
  engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
  GetOutput(&y_cpu);

  ASSERT_EQ(y_cpu[0], 2.0);
  ASSERT_EQ(y_cpu[1], 5.0);
...
...
paddle/fluid/inference/tests/api/trt_models_tester.cc
...
...
@@ -54,7 +54,8 @@ void SetConfig<AnalysisConfig>(AnalysisConfig* config, std::string model_dir,
  if (use_gpu) {
    config->EnableUseGpu(100, 0);
    if (use_tensorrt) {
      config->EnableTensorRtEngine(1 << 10, batch_size);
      config->EnableTensorRtEngine(1 << 10, batch_size, 3,
                                   AnalysisConfig::Precision::kFloat32, false);
      config->pass_builder()->DeletePass("conv_bn_fuse_pass");
      config->pass_builder()->DeletePass("fc_fuse_pass");
      config->pass_builder()->TurnOnDebug();
...
...
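EnableTensorRtEngine now takes the minimum subgraph size, the precision mode and a use_static flag in addition to the workspace size and max batch size. A configuration sketch (the model path is an assumption, and use_static is assumed here to control caching of the serialized engine):

// Sketch: enable the TensorRT subgraph engine with the extended signature.
AnalysisConfig config("/path/to/model_dir");  // assumed model directory
config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine(1 << 20 /*workspace_size*/, 1 /*max_batch_size*/,
                            3 /*min_subgraph_size*/,
                            AnalysisConfig::Precision::kFloat32,
                            false /*use_static*/);
auto predictor = CreatePaddlePredictor(config);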
paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
...
...
@@ -30,6 +30,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
    AddOutput("Ys", "A list of outputs").AsDuplicable();
    AddAttr<std::string>("subgraph", "the subgraph.");
    AddAttr<std::string>("calibration_data", "the calibration data for int8");
    AddAttr<std::string>(
        "engine_serialized_data",
        "the serialized data contains the all info of the ICUDAEngine");
    AddAttr<std::string>(
        "engine_key",
        "The engine_key here is used to distinguish different TRT Engines");
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
...
...
@@ -16,8 +16,10 @@
#ifdef PADDLE_WITH_CUDA
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
...
...
@@ -31,37 +33,6 @@ namespace paddle {
namespace operators {

using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;

namespace {  // NOLINT

TRT_DT FluidDataType2TRT(FluidDT type) {
  switch (type) {
    case FluidDT::VarType_Type_FP32:
      return TRT_DT::kFLOAT;
    case FluidDT::VarType_Type_INT32:
      return TRT_DT::kINT32;
    default:
      return TRT_DT::kINT32;
  }
  PADDLE_THROW("unkown type");
  return TRT_DT::kINT32;
}

nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
  PADDLE_ENFORCE_GT(shape.size(), 1UL,
                    "TensorRT' tensor input requires at least 2 dimensions");
  PADDLE_ENFORCE_LE(shape.size(), 4UL,
                    "TensorRT' tensor input requires at most 4 dimensions");
  PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL);
  if (shape.size() == 4UL)
    return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
  return nvinfer1::DimsCHW(shape[1], 1, 1);
}

}  // namespace  // NOLINT

using inference::Singleton;
using inference::tensorrt::TensorRTEngine;
using inference::tensorrt::TRTInt8Calibrator;
...
...
@@ -79,6 +50,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
  bool enable_int8_;
  std::string calibration_data_;
  std::string engine_key_;
  std::string engine_serialized_data_;
  bool calibration_mode_;

 public:
...
...
@@ -93,6 +65,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
    enable_int8_ = Attr<bool>("enable_int8");
    calibration_data_ = Attr<std::string>("calibration_data");
    engine_key_ = Attr<std::string>("engine_key");
    engine_serialized_data_ = Attr<std::string>("engine_serialized_data");
    auto params = Attr<std::vector<std::string>>("parameters");
    for (const auto &param : params) {
...
@@ -125,7 +98,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunCalibration
(
scope
,
dev_place
);
return
;
}
RunTrt
(
scope
,
dev_place
);
auto
*
trt_engine
=
GetEngine
(
scope
,
dev_place
);
RunTrt
(
scope
,
dev_place
,
trt_engine
);
}
void
RunCalibration
(
const
framework
::
Scope
&
scope
,
...
...
@@ -136,10 +110,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
    LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
                         << " is running calibration trt int8... ";
    int runtime_batch = 1;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
    if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
      TRTCalibratorEngine *calib_res =
          Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
...
...
@@ -156,11 +126,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
          calib_buffers, runtime_batch, engine_key_, dev_place));
      calib_res->thr_.reset(new std::thread([&]() {
        calib_res->engine_.reset(new TensorRTEngine(
            max_batch_size_, workspace_size_, stream,
            boost::get<platform::CUDAPlace>(dev_place).device, enable_int8_,
            calib_res->calib_.get()));
            max_batch_size_, workspace_size_, enable_int8_,
            calib_res->calib_.get(),
            boost::get<platform::CUDAPlace>(dev_place).device));
        VLOG(3) << "start the calib trt engine thread";
        Prepare(scope, dev_place, calib_res->engine_.get());
        PrepareTRTEngine(scope, calib_res->engine_.get());
      }));
    }
...
...
@@ -180,28 +150,29 @@ class TensorRTEngineOp : public framework::OperatorBase {
    RunNativeImpl(scope, dev_place);
  }

  void RunTrt(const framework::Scope &scope,
              const platform::Place &dev_place) const {
  void RunTrt(const framework::Scope &scope, const platform::Place &dev_place,
              TensorRTEngine *engine) const {
    int runtime_batch = 1;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
    if (trt_engine_.get() == nullptr) {
      trt_engine_.reset(
          new TensorRTEngine(max_batch_size_, workspace_size_, stream,
                             boost::get<platform::CUDAPlace>(dev_place).device,
                             enable_int8_, calibrator_.get()));
      Prepare(scope, dev_place, trt_engine_.get());
    }

    auto *engine = trt_engine_.get();
    PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");

    std::vector<std::string> output_maps =
        Attr<std::vector<std::string>>("output_name_mapping");

    // Convert input tensor from fluid to engine.
    int num_inputs = 0;
    for (const auto &x : Inputs("Xs")) {
      if (param_names_.count(x)) continue;
      num_inputs += 1;
    }
    const int num_bindings = num_inputs + Outputs("Ys").size();
    std::vector<void *> buffers(num_bindings);

    // Bind input tensor to TRT.
    for (const auto &x : Inputs("Xs")) {
      if (param_names_.count(x)) continue;
      // convert input and copy to TRT engine's buffer
...
...
@@ -209,28 +180,20 @@ class TensorRTEngineOp : public framework::OperatorBase {
          inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
      auto t_shape = framework::vectorize(t.dims());
      runtime_batch = t_shape[0];
      if (platform::is_cpu_place(t.place())) {
        engine->SetInputFromCPU(x, static_cast<const void *>(t.data<void>()),
                                t.memory_size());
      } else {
        engine->SetInputFromGPU(x, static_cast<const void *>(t.data<void>()),
                                t.memory_size());
      }
    }
    cudaStreamSynchronize(stream);
    PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_);
    // Execute the engine.
    engine->Execute(runtime_batch);
      const int bind_index = engine->engine()->getBindingIndex(x.c_str());
      PADDLE_ENFORCE(bind_index < num_bindings,
                     "The bind index should be less than num_bindings");
      buffers[bind_index] = static_cast<void *>(t.data<float>());
    }

    // Convert output tensor from engine to fluid
    // Bind output tensor to TRT.
    int output_index = 0;
    VLOG(4) << "TensorRT Engine Op Outputs:";
    for (const auto &y : Outputs("Ys")) {
      VLOG(4) << y;
      // convert output and copy to fluid.
      nvinfer1::ITensor *trt_t = engine->GetITensor(output_maps[output_index]);
      auto dims = trt_t->getDimensions();
      const int bind_index =
          engine->engine()->getBindingIndex(output_maps[output_index].c_str());
      auto dims = engine->engine()->getBindingDimensions(bind_index);
      // Use the output ITensor's dims to reshape the Fluid Tensor.
      // The ITensor doesn't contain the batch size dim.
      std::vector<int> ddim;
...
...
@@ -238,71 +201,55 @@ class TensorRTEngineOp : public framework::OperatorBase {
      for (int i = 0; i < dims.nbDims; i++) {
        ddim.push_back(dims.d[i]);
      }
      auto *fluid_v = scope.FindVar(y);
      PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
      auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
      fluid_t->Resize(framework::make_ddim(ddim));

      // TODO(Superjomn) change this float to dtype size.
      auto size =
          inference::analysis::AccuDims(dims.d, dims.nbDims) * runtime_batch;
      engine->GetOutputInGPU(
          output_maps[output_index],
          fluid_t->mutable_data<float>(platform::CUDAPlace(
              boost::get<platform::CUDAPlace>(dev_place).device)),
          size * sizeof(float));
      PADDLE_ENFORCE(bind_index < num_bindings,
                     "The bind index should be less than num_bindings");
      buffers[bind_index] = static_cast<void *>(fluid_t->mutable_data<float>(
          boost::get<platform::CUDAPlace>(dev_place)));

      output_index += 1;
    }

    PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_);
    // Execute the engine.
    engine->Execute(runtime_batch, &buffers, stream);
    cudaStreamSynchronize(stream);
  }

  void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
               TensorRTEngine *engine) const {
  TensorRTEngine *GetEngine(const framework::Scope &scope,
                            const platform::Place &dev_place) const {
    if (!trt_engine_) {
      trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
          max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
          boost::get<platform::CUDAPlace>(dev_place).device));
      if (!engine_serialized_data_.empty()) {
        trt_engine_->Deserialize(engine_serialized_data_);
      } else {
        PrepareTRTEngine(scope, trt_engine_.get());
      }
    }
    return trt_engine_.get();
  }

  void PrepareTRTEngine(const framework::Scope &scope,
                        TensorRTEngine *engine) const {
    LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
                 "kernel etc). This process may cost a lot of time.";
    framework::proto::BlockDesc block_desc;
    block_desc.ParseFromString(Attr<std::string>("subgraph"));
    framework::proto::BlockDesc block_proto;
    block_proto.ParseFromString(Attr<std::string>("subgraph"));
    framework::BlockDesc block_desc(nullptr, &block_proto);

    std::vector<std::string> output_maps =
    std::vector<std::string> inputs = Inputs("Xs");
    std::vector<std::string> outputs =
        Attr<std::vector<std::string>>("output_name_mapping");

    engine->InitNetwork();

    framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
    VLOG(4) << "parsed var size " << block.AllVars().size();
    // Add inputs
    VLOG(4) << "declare inputs";
    for (auto &input : Inputs("Xs")) {
      if (param_names_.count(input)) continue;
      VLOG(4) << "declare input " << input;

      auto &t =
          inference::analysis::GetFromScope<framework::LoDTensor>(scope, input);
      auto t_shape = framework::vectorize(t.dims());

      auto *var = block.FindVar(input);
      // TensorRT engine need to create parameters. The parameter's description
      // should be set in
      PADDLE_ENFORCE(var, "no variable called %s", input);
      PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
                        "TensorRT engine only takes LoDTensor as input");

      engine->DeclareInput(
          input, FluidDataType2TRT(
                     var->Proto()->type().lod_tensor().tensor().data_type()),
          Vec2TRT_Dims(t_shape));
    }
    inference::Singleton<inference::tensorrt::OpConverter>::Global()
        .ConvertBlock(block_desc, param_names_, scope, engine);
    // Add outputs
    for (auto &output : output_maps) {
      engine->DeclareOutput(output);
    }
    engine->FreezeNetwork();
        .ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_,
                                 outputs, engine);
  }
};
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
...
...
@@ -107,6 +107,7 @@ TEST(TensorRTEngineOp, manual) {
  engine_op_desc.SetAttr("output_name_mapping",
                         std::vector<std::string>({"z0"}));
  engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
  engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
  LOG(INFO) << "create engine op";
  auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
...
...
@@ -202,6 +203,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
  engine_op_desc.SetAttr("output_name_mapping",
                         std::vector<std::string>({"z3"}));
  engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
  engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
  auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
...
...
paddle/fluid/pybind/inference_api.cc
...
...
@@ -221,7 +221,8 @@ void BindAnalysisConfig(py::module *m) {
      .def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
           py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
           py::arg("min_subgraph_size") = 3,
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
           py::arg("use_static") = true)
      .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
      .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
           py::arg("x") = true)
...
...