Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
88c24baa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
88c24baa
编写于
2月 14, 2019
作者:
N
nhzlx
提交者:
ceci3
3月 08, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add static model load for trt
1. bind trt input and output to fluid tensors
上级
e1c28e7f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
313 addition
and
344 deletion
+313
-344
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+113
-62
paddle/fluid/inference/engine.h
paddle/fluid/inference/engine.h
+0
-5
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+1
-18
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+43
-26
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+7
-110
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+2
-39
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+89
-43
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+58
-41
未找到文件。
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
88c24baa
...
...
@@ -33,6 +33,14 @@ using framework::ir::Node;
std
::
vector
<
std
::
string
>
ExtractParameters
(
const
std
::
unordered_set
<
Node
*>
&
nodes
);
void
RenameAndGetOutputs
(
const
std
::
vector
<
framework
::
ir
::
Node
*>
&
subgraph_nodes
,
framework
::
BlockDesc
*
block_desc
,
const
std
::
set
<
std
::
string
>
&
input_names_with_id
,
std
::
set
<
std
::
string
>
*
output_names_with_id
,
std
::
set
<
std
::
string
>
*
output_names
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
output_name_map
);
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
analysis
::
TensorRtSubgraphPass
::
ApplyImpl
(
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph
)
const
{
...
...
@@ -120,9 +128,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
input_names
.
insert
(
x
->
Name
());
input_names_with_id
.
insert
(
x
->
Name
()
+
std
::
to_string
(
x
->
id
()));
}
op_desc
->
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
std
::
set
<
std
::
string
>
output_names
;
std
::
set
<
std
::
string
>
output_names_with_id
;
for
(
auto
*
x
:
node
->
outputs
)
{
...
...
@@ -130,11 +135,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
output_names_with_id
.
insert
(
x
->
Name
()
+
std
::
to_string
(
x
->
id
()));
}
op_desc
->
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
(
output_names
.
begin
(),
output_names
.
end
()));
op_desc
->
SetType
(
"tensorrt_engine"
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_name_map
;
auto
&
subgraph_nodes
=
*
Agent
(
node
).
subgraph
();
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
...
...
@@ -148,61 +150,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// input of a OP, but also the output of a Op, there will be problems.
// So we have to rename the variable in the subgraph to make sure
// it is either an OP's input or an OP's output.
auto
&
subgraph_nodes
=
*
Agent
(
node
).
subgraph
();
for
(
size_t
index
=
0
;
index
<
block_desc
.
OpSize
();
++
index
)
{
framework
::
proto
::
OpDesc
*
op
=
block_desc
.
Op
(
index
)
->
Proto
();
auto
correspond_node
=
subgraph_nodes
[
index
];
PADDLE_ENFORCE_EQ
(
correspond_node
->
Name
(),
op
->
type
());
std
::
unordered_map
<
std
::
string
,
size_t
>
var2id
;
for
(
auto
*
in_var
:
correspond_node
->
inputs
)
{
var2id
[
in_var
->
Name
()]
=
in_var
->
id
();
}
// rename for the input variables of op inside subgraph
for
(
int
i
=
0
;
i
<
op
->
inputs_size
();
i
++
)
{
// one input
auto
*
in_var
=
op
->
mutable_inputs
(
i
);
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
in_var
->
arguments_size
();
k
++
)
{
// all the arguments
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
input_names_with_id
.
count
(
arg_value_with_id
))
{
replaced_names
.
push_back
(
arg_value
);
}
else
{
replaced_names
.
push_back
(
arg_value_with_id
);
}
}
in_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
in_var
->
add_arguments
(
replaced_names
[
k
]);
}
}
var2id
.
clear
();
for
(
auto
out_var
:
correspond_node
->
outputs
)
{
var2id
[
out_var
->
Name
()]
=
out_var
->
id
();
}
// rename for the output variables of op inside subgraph
for
(
int
i
=
0
;
i
<
op
->
outputs_size
();
i
++
)
{
framework
::
proto
::
OpDesc_Var
*
out_var
=
op
->
mutable_outputs
(
i
);
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
out_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
output_names_with_id
.
count
(
arg_value_with_id
))
{
output_name_map
[
arg_value
]
=
arg_value_with_id
;
}
replaced_names
.
push_back
(
arg_value_with_id
);
}
out_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
out_var
->
add_arguments
(
replaced_names
[
k
]);
}
}
}
RenameAndGetOutputs
(
subgraph_nodes
,
&
block_desc
,
input_names_with_id
,
&
output_names_with_id
,
&
output_names
,
&
output_name_map
);
// When tensorrt engine runs at the end of the operation,
// output_mapping help us copy the data from the renamed ITensor
...
...
@@ -222,6 +171,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
PADDLE_ENFORCE
(
!
block_desc
.
Proto
()
->
vars
().
empty
(),
"the block has no var-desc"
);
op_desc
->
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
op_desc
->
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
(
output_names
.
begin
(),
output_names
.
end
()));
op_desc
->
SetType
(
"tensorrt_engine"
);
PADDLE_ENFORCE
(
!
output_mapping
.
empty
());
op_desc
->
SetBlockAttr
(
"sub_block"
,
new_block
);
SetAttr
(
op_desc
->
Proto
(),
"subgraph"
,
...
...
@@ -236,6 +193,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
auto
engine_key
=
GenerateEngineKey
(
input_names_with_id
,
output_names_with_id
);
// Get "" when there is no cached calibration table data.
std
::
string
calibration_data
=
GetTrtCalibTableData
(
Get
<
std
::
string
>
(
"model_opt_cache_dir"
),
engine_key
,
enable_int8
);
SetAttr
(
op_desc
->
Proto
(),
"calibration_data"
,
calibration_data
);
...
...
@@ -272,6 +230,99 @@ std::vector<std::string> ExtractParameters(
return
parameters
;
}
void
RenameAndGetOutputs
(
const
std
::
vector
<
framework
::
ir
::
Node
*>
&
subgraph_nodes
,
framework
::
BlockDesc
*
block_desc
,
const
std
::
set
<
std
::
string
>
&
input_names_with_id
,
std
::
set
<
std
::
string
>
*
output_names_with_id
,
std
::
set
<
std
::
string
>
*
output_names
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
output_name_map
)
{
//// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into one conv, and then trigger bug. So, We should use strategy to avoid
// this optimization for the time being. This bug will be fixed in the future.
std
::
unordered_map
<
std
::
string
/*name*/
,
int
/*ITensor_quote_num*/
>
same_hierarchy_conv2d_num_map
;
for
(
size_t
index
=
0
;
index
<
block_desc
->
OpSize
();
++
index
)
{
framework
::
proto
::
OpDesc
*
op
=
block_desc
->
Op
(
index
)
->
Proto
();
framework
::
OpDesc
op_desc
(
*
op
,
nullptr
);
auto
correspond_node
=
subgraph_nodes
[
index
];
PADDLE_ENFORCE_EQ
(
correspond_node
->
Name
(),
op
->
type
());
std
::
unordered_map
<
std
::
string
,
size_t
>
var2id
;
std
::
unordered_map
<
std
::
string
,
framework
::
ir
::
Node
*>
in_vars
;
for
(
auto
*
in_var
:
correspond_node
->
inputs
)
{
var2id
[
in_var
->
Name
()]
=
in_var
->
id
();
in_vars
[
in_var
->
Name
()]
=
in_var
;
}
// rename for the input variables of op inside subgraph
for
(
int
i
=
0
;
i
<
op
->
inputs_size
();
i
++
)
{
// one input
auto
*
in_var
=
op
->
mutable_inputs
(
i
);
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
in_var
->
arguments_size
();
k
++
)
{
// all the arguments
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
input_names_with_id
.
count
(
arg_value_with_id
))
{
replaced_names
.
push_back
(
arg_value
);
}
else
{
replaced_names
.
push_back
(
arg_value_with_id
);
}
}
in_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
in_var
->
add_arguments
(
replaced_names
[
k
]);
}
}
var2id
.
clear
();
for
(
auto
out_var
:
correspond_node
->
outputs
)
{
var2id
[
out_var
->
Name
()]
=
out_var
->
id
();
}
if
(
op_desc
.
Type
()
==
"conv2d"
)
{
auto
input_var_name
=
op_desc
.
Input
(
"Input"
).
front
();
auto
filter_var_name
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
out_var_name
=
op_desc
.
Output
(
"Output"
).
front
();
auto
filter_shape
=
in_vars
[
filter_var_name
]
->
Var
()
->
GetShape
();
const
std
::
vector
<
int
>
strides
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"strides"
));
const
std
::
vector
<
int
>
paddings
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
if
(
same_hierarchy_conv2d_num_map
[
input_var_name
]
>
0
)
{
(
*
output_names_with_id
)
.
insert
(
out_var_name
+
std
::
to_string
(
var2id
[
out_var_name
]));
(
*
output_names
).
insert
(
out_var_name
);
}
else
if
(
filter_shape
[
2
]
==
1
&&
filter_shape
[
3
]
==
1
&&
strides
[
0
]
==
1
&&
strides
[
1
]
==
1
&&
paddings
[
0
]
==
0
&&
paddings
[
1
]
==
0
)
{
same_hierarchy_conv2d_num_map
[
input_var_name
]
+=
1
;
}
}
// rename for the output variables of op inside subgraph
for
(
int
i
=
0
;
i
<
op
->
outputs_size
();
i
++
)
{
framework
::
proto
::
OpDesc_Var
*
out_var
=
op
->
mutable_outputs
(
i
);
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
out_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
output_names_with_id
->
count
(
arg_value_with_id
))
{
(
*
output_name_map
)[
arg_value
]
=
arg_value_with_id
;
}
replaced_names
.
push_back
(
arg_value_with_id
);
}
out_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
out_var
->
add_arguments
(
replaced_names
[
k
]);
}
}
}
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
...
...
paddle/fluid/inference/engine.h
浏览文件 @
88c24baa
...
...
@@ -49,11 +49,6 @@ class EngineBase {
// Execute the engine, that will run the inference network.
virtual
void
Execute
(
int
batch_size
)
=
0
;
// Return the IO buffer that allocated in engine. One can read/write directly
// on the buffer. If the buffer's buffer is nullptr, one can also allocate
// memory and maintain it outside the engine.
virtual
Buffer
&
buffer
(
const
std
::
string
&
name
)
=
0
;
virtual
~
EngineBase
()
{}
};
// class EngineBase
...
...
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
88c24baa
...
...
@@ -18,21 +18,6 @@ namespace paddle {
namespace
inference
{
namespace
tensorrt
{
bool
to_skip_merging_optimize
(
TensorRTEngine
*
engine
,
const
std
::
vector
<
int
>&
filters
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
std
::
string
input_name
)
{
if
(
engine
->
itensor_quote_num
[
input_name
]
>
0
)
{
return
true
;
}
if
(
filters
[
0
]
==
1
&&
filters
[
1
]
==
1
&&
strides
[
0
]
==
1
&&
strides
[
1
]
==
1
&&
paddings
[
0
]
==
0
&&
paddings
[
1
]
==
0
)
engine
->
itensor_quote_num
[
input_name
]
+=
1
;
return
false
;
}
template
<
typename
RegistFunc
,
typename
SetDilationFunc
>
void
ConvertConv2d
(
TensorRTEngine
*
engine
,
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
,
...
...
@@ -100,9 +85,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
||
to_skip_merging_optimize
(
engine
,
{
filter_h
,
filter_w
},
strides
,
paddings
,
op_desc
.
Input
(
"Input"
).
front
()))
{
if
(
test_mode
)
{
engine
->
DeclareOutput
(
output_name
);
}
}
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
88c24baa
...
...
@@ -146,19 +146,6 @@ class TRTConvertValidation {
// Declare outputs.
op_desc_
.
reset
(
new
framework
::
OpDesc
(
desc
,
nullptr
));
// Set Inputs.
for
(
const
auto
&
input
:
op_desc_
->
InputArgumentNames
())
{
if
(
parameters_
.
count
(
input
))
continue
;
auto
*
var
=
scope_
.
FindVar
(
input
);
PADDLE_ENFORCE
(
var
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
engine_
->
SetInputFromGPU
(
input
,
static_cast
<
void
*>
(
tensor
->
data
<
void
>
()),
sizeof
(
float
)
*
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
}
}
// We use the set 'neglected_output' here, because some Ops like batch norm,
...
...
@@ -171,34 +158,64 @@ class TRTConvertValidation {
platform
::
CUDAPlace
place
;
platform
::
CUDADeviceContext
ctx
(
place
);
op_
->
Run
(
scope_
,
place
);
std
::
vector
<
std
::
string
>
input_output_names
;
// Note: we need filter the parameter
for
(
const
auto
&
input
:
op_desc_
->
InputArgumentNames
())
{
if
(
parameters_
.
count
(
input
))
continue
;
input_output_names
.
push_back
(
input
);
}
// Collect the fluid outputs.
std
::
vector
<
std
::
vector
<
float
>>
fluid_outs
;
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
if
(
neglected_output
.
count
(
output
))
continue
;
input_output_names
.
push_back
(
output
);
std
::
vector
<
float
>
fluid_out
;
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorToVector
(
*
tensor
,
ctx
,
&
fluid_out
);
fluid_outs
.
push_back
(
fluid_out
);
}
// Bind input and output for TRT.
const
int
num_bindings
=
input_output_names
.
size
();
std
::
vector
<
void
*>
buffers
(
num_bindings
);
for
(
const
std
::
string
&
name
:
input_output_names
)
{
auto
*
var
=
scope_
.
FindVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
const
int
bind_index
=
engine_
->
engine
()
->
getBindingIndex
(
name
.
c_str
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
tensor
->
mutable_data
<
float
>
(
place
));
}
// Execute TRT.
engine_
->
Execute
(
batch_size
);
engine_
->
Execute
(
batch_size
,
buffers
);
cudaStreamSynchronize
(
engine_
->
stream
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
const
size_t
output_space_size
=
300
0
;
int
index
=
0
;
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
if
(
neglected_output
.
count
(
output
))
continue
;
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
trt_out
(
output_space_size
);
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
output_space_size
);
cudaStreamSynchronize
(
engine_
->
stream
());
std
::
vector
<
float
>
trt_out
;
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorToVector
(
*
tensor
,
ctx
,
&
fluid
_out
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorToVector
(
*
tensor
,
ctx
,
&
trt
_out
);
size_t
fluid_out_size
=
fluid_out
.
size
();
size_t
fluid_out_size
=
fluid_out
s
[
index
]
.
size
();
if
(
if_add_batch_
==
true
)
{
fluid_out_size
=
batch_size
*
(
framework
::
product
(
tensor
->
dims
())
/
max_batch_size_
);
}
// Compare two output
ASSERT_FALSE
(
fluid_out
.
empty
());
for
(
size_t
i
=
0
;
i
<
fluid_out_size
;
i
++
)
{
// Loose the threshold for CI in different machine model.
EXPECT_LT
(
std
::
abs
(
fluid_out
[
i
]
-
trt_out
[
i
]),
2e-5
);
EXPECT_LT
(
std
::
abs
(
fluid_out
s
[
index
]
[
i
]
-
trt_out
[
i
]),
2e-5
);
}
index
+=
1
;
}
}
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
88c24baa
...
...
@@ -32,8 +32,14 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
PADDLE_ENFORCE
(
false
,
"not implemented"
);
}
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
&
buffers
)
{
batch_size_
=
batch_size
;
infer_context_
->
enqueue
(
batch_size
,
buffers
.
data
(),
stream_
,
nullptr
);
cudaStreamSynchronize
(
stream_
);
SetRuntimeBatch
(
batch_size
);
}
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
freshDeviceId
();
batch_size_
=
batch_size
;
std
::
vector
<
void
*>
buffers
;
for
(
auto
&
buf
:
buffers_
)
{
...
...
@@ -61,7 +67,6 @@ TensorRTEngine::~TensorRTEngine() {
void
TensorRTEngine
::
FreezeNetwork
()
{
VLOG
(
3
)
<<
"TRT to freeze network"
;
freshDeviceId
();
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
"Call InitNetwork first to initialize network."
);
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
...
...
@@ -81,30 +86,6 @@ void TensorRTEngine::FreezeNetwork() {
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"build cuda engine failed!"
);
infer_context_
.
reset
(
infer_engine_
->
createExecutionContext
());
// allocate GPU buffers.
buffers_
.
resize
(
buffer_sizes_
.
size
());
for
(
auto
&
item
:
buffer_sizes_
)
{
// The output buffers are not set in the network building phrase, need to
// infer from the TesorRT network.
if
(
item
.
second
==
0
)
{
auto
slot_offset
=
infer_engine_
->
getBindingIndex
(
item
.
first
.
c_str
());
auto
dims
=
infer_engine_
->
getBindingDimensions
(
slot_offset
);
item
.
second
=
kDataTypeSize
[
static_cast
<
int
>
(
infer_engine_
->
getBindingDataType
(
slot_offset
))]
*
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
)
*
max_batch_
;
PADDLE_ENFORCE_GT
(
item
.
second
,
0
);
}
auto
&
buf
=
buffer
(
item
.
first
);
buf
.
max_size
=
item
.
second
*
max_batch_
;
CHECK
(
buf
.
buffer
==
nullptr
);
// buffer should be allocated only once.
PADDLE_ENFORCE_EQ
(
0
,
cudaMalloc
(
&
buf
.
buffer
,
item
.
second
*
max_batch_
));
buf
.
size
=
0
;
PADDLE_ENFORCE_LE
(
buf
.
max_size
,
1
<<
30
);
// 10G
buf
.
device
=
DeviceType
::
GPU
;
}
}
nvinfer1
::
ITensor
*
TensorRTEngine
::
DeclareInput
(
const
std
::
string
&
name
,
...
...
@@ -158,83 +139,6 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
buffer_sizes_
[
name
]
=
0
;
}
void
*
TensorRTEngine
::
GetOutputInGPU
(
const
std
::
string
&
name
)
{
return
buffer
(
name
).
buffer
;
}
void
TensorRTEngine
::
GetOutputInGPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
)
{
// determine data size
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
nvinfer1
::
Dims
dims
=
output
->
getDimensions
();
auto
dim_size
=
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
size_t
dst_size
=
dim_size
*
runtime_batch_
*
kDataTypeSize
[
static_cast
<
int
>
(
output
->
getType
())];
auto
it
=
buffer_sizes_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
());
PADDLE_ENFORCE_GT
(
it
->
second
,
0
);
PADDLE_ENFORCE_LE
(
dst_size
,
it
->
second
);
PADDLE_ENFORCE_GE
(
max_size
,
dst_size
);
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
,
"buffer should be allocated before"
);
PADDLE_ENFORCE_EQ
(
cudaMemcpyAsync
(
dst
,
buf
.
buffer
,
dst_size
,
cudaMemcpyDeviceToDevice
,
stream_
),
0
);
}
void
TensorRTEngine
::
GetOutputInCPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
)
{
// determine data size
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
nvinfer1
::
Dims
dims
=
output
->
getDimensions
();
auto
dim_size
=
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
size_t
dst_size
=
dim_size
*
runtime_batch_
*
kDataTypeSize
[
static_cast
<
int
>
(
output
->
getType
())];
auto
it
=
buffer_sizes_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
());
PADDLE_ENFORCE_GT
(
it
->
second
,
0
);
PADDLE_ENFORCE_LE
(
dst_size
,
it
->
second
);
PADDLE_ENFORCE_GE
(
max_size
,
dst_size
);
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
,
"buffer should be allocated before"
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
dst
,
buf
.
buffer
,
dst_size
,
cudaMemcpyDeviceToHost
,
stream_
));
}
Buffer
&
TensorRTEngine
::
buffer
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"call FreezeNetwork first."
);
auto
it
=
buffer_sizes_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
(),
"tried to access buffer named %s"
,
name
);
auto
slot_offset
=
infer_engine_
->
getBindingIndex
(
name
.
c_str
());
return
buffers_
[
slot_offset
];
}
void
TensorRTEngine
::
SetInputFromCPU
(
const
std
::
string
&
name
,
const
void
*
data
,
size_t
size
)
{
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
);
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE_LE
(
size
,
buf
.
max_size
,
"buffer is too small"
);
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
buf
.
size
=
size
;
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
.
buffer
,
data
,
size
,
cudaMemcpyHostToDevice
,
stream_
));
}
void
TensorRTEngine
::
SetInputFromGPU
(
const
std
::
string
&
name
,
const
void
*
data
,
size_t
size
)
{
auto
&
buf
=
buffer
(
name
);
buf
.
size
=
size
;
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
);
PADDLE_ENFORCE_LE
(
size
,
buf
.
max_size
,
"buffer is too small"
);
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
.
buffer
,
data
,
size
,
cudaMemcpyDeviceToDevice
,
stream_
));
}
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
)
{
PADDLE_ENFORCE
(
tensor
!=
nullptr
);
...
...
@@ -254,13 +158,6 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
void
TensorRTEngine
::
freshDeviceId
()
{
int
count
;
cudaGetDeviceCount
(
&
count
);
PADDLE_ENFORCE_LT
(
device_
,
count
);
cudaSetDevice
(
device_
);
}
nvinfer1
::
IPluginLayer
*
TensorRTEngine
::
AddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRT
*
plugin
)
{
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
88c24baa
...
...
@@ -57,13 +57,12 @@ class TensorRTEngine : public EngineBase {
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
stream
,
int
device
=
0
,
bool
enable_int8
=
false
,
bool
enable_int8
=
false
,
TRTInt8Calibrator
*
calibrator
=
nullptr
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
),
device_
(
device
),
enable_int8_
(
enable_int8
),
calibrator_
(
calibrator
),
logger_
(
logger
)
{}
...
...
@@ -74,6 +73,7 @@ class TensorRTEngine : public EngineBase {
void
Build
(
const
DescType
&
paddle_model
)
override
;
void
Execute
(
int
batch_size
)
override
;
void
Execute
(
int
batch_size
,
std
::
vector
<
void
*>&
buffers
);
// Initialize the inference network, so that TensorRT layers can add to this
// network.
...
...
@@ -98,28 +98,8 @@ class TensorRTEngine : public EngineBase {
// Check if the ITensor has been declared
bool
HasDeclared
(
const
std
::
string
&
name
);
// GPU memory address for an ITensor with specific name. One can operate on
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
Buffer
&
buffer
(
const
std
::
string
&
name
)
override
;
cudaStream_t
stream
()
{
return
stream_
;
}
// Fill an input from CPU memory with name and size.
void
SetInputFromCPU
(
const
std
::
string
&
name
,
const
void
*
data
,
size_t
size
);
// TODO(Superjomn) is this method necessary given that buffer(xxx) can be
// accessed directly. Fill an input from GPU memory with name and size.
void
SetInputFromGPU
(
const
std
::
string
&
name
,
const
void
*
data
,
size_t
size
);
// Get an output called name, the output of tensorrt is in GPU, so this method
// Return the output's GPU memory address without copy.
void
*
GetOutputInGPU
(
const
std
::
string
&
name
);
// Copy data into dst inside the GPU device.
void
GetOutputInGPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
);
// LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
// to CPU.
void
GetOutputInCPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
);
// Fill an ITensor into map itensor_map_.
void
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
);
// Get an ITensor called name.
nvinfer1
::
ITensor
*
GetITensor
(
const
std
::
string
&
name
);
...
...
@@ -128,7 +108,6 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
void
SetRuntimeBatch
(
size_t
batch_size
);
int
GetRuntimeBatch
();
int
GetDevice
()
{
return
device_
;
}
nvinfer1
::
IPluginLayer
*
AddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRT
*
);
...
...
@@ -140,16 +119,6 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
std
::
unique_ptr
<
framework
::
Tensor
>>
weight_map
;
// TODO(NHZLX)
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into one conv, and then trigger bug. So, We should use strategy to avoid
// this
// optimization for the time being. This bug will be fixed in the future.
std
::
unordered_map
<
std
::
string
/*name*/
,
int
/*ITensor_quote_num*/
>
itensor_quote_num
;
private:
// the max batch size
int
max_batch_
;
...
...
@@ -159,8 +128,6 @@ class TensorRTEngine : public EngineBase {
int
max_workspace_
;
cudaStream_t
stream_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
bool
enable_int8_
;
TRTInt8Calibrator
*
calibrator_
;
...
...
@@ -192,10 +159,6 @@ class TensorRTEngine : public EngineBase {
infer_ptr
<
nvinfer1
::
INetworkDefinition
>
infer_network_
;
infer_ptr
<
nvinfer1
::
ICudaEngine
>
infer_engine_
;
infer_ptr
<
nvinfer1
::
IExecutionContext
>
infer_context_
;
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void
freshDeviceId
();
};
// class TensorRTEngine
// Add an layer__ into engine__ with args ARGS.
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
88c24baa
...
...
@@ -17,6 +17,8 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/enforce.h"
...
...
@@ -27,19 +29,29 @@ namespace tensorrt {
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream_
));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
stream_
);
ctx_
=
new
platform
::
CUDADeviceContext
(
platform
::
CUDAPlace
(
0
));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
ctx_
->
stream
());
engine_
->
InitNetwork
();
}
void
TearDown
()
override
{
delete
engine_
;
cudaStreamDestroy
(
stream_
);
void
TearDown
()
override
{
delete
engine_
;
}
void
PrepareInputOutput
(
const
std
::
vector
<
float
>
&
input
,
std
::
vector
<
int
>
output_shape
)
{
TensorFromVector
(
input
,
*
ctx_
,
&
input_
);
output_
.
Resize
(
framework
::
make_ddim
(
output_shape
));
}
void
GetOutput
(
std
::
vector
<
float
>
*
output
)
{
TensorToVector
(
output_
,
*
ctx_
,
output
);
}
protected:
TensorRTEngine
*
engine_
;
cudaStream_t
stream_
;
framework
::
Tensor
input_
;
framework
::
Tensor
output_
;
TensorRTEngine
*
engine_
;
platform
::
CUDADeviceContext
*
ctx_
;
};
TEST_F
(
TensorRTEngineTest
,
add_layer
)
{
...
...
@@ -48,12 +60,14 @@ TEST_F(TensorRTEngineTest, add_layer) {
float
raw_weight
[
size
]
=
{
2.
};
// Weight in CPU memory.
float
raw_bias
[
size
]
=
{
3.
};
std
::
vector
<
void
*>
buffers
(
2
);
// TRT binded inputs
LOG
(
INFO
)
<<
"create weights"
;
TensorRTEngine
::
Weight
weight
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_weight
,
size
);
TensorRTEngine
::
Weight
bias
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_bias
,
size
);
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
DimsCHW
{
1
,
1
,
1
});
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
x
,
size
,
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
x
,
size
,
weight
.
get
(),
bias
.
get
());
PADDLE_ENFORCE
(
fc_layer
!=
nullptr
);
...
...
@@ -63,18 +77,24 @@ TEST_F(TensorRTEngineTest, add_layer) {
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
// fill in real data
float
x_v
=
1234
;
engine_
->
SetInputFromCPU
(
"x"
,
reinterpret_cast
<
void
*>
(
&
x_v
),
1
*
sizeof
(
float
));
std
::
vector
<
float
>
x_v
=
{
1234
};
std
::
vector
<
float
>
y_cpu
;
PrepareInputOutput
(
x_v
,
{
1
});
auto
*
x_v_gpu_data
=
input_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
y_gpu_data
=
output_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
LOG
(
INFO
)
<<
"to execute"
;
engine_
->
Execute
(
1
);
engine_
->
Execute
(
1
,
buffers
);
LOG
(
INFO
)
<<
"to get output"
;
float
y_cpu
;
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
,
1
*
sizeof
(
float
));
GetOutput
(
&
y_cpu
);
LOG
(
INFO
)
<<
"to checkout output"
;
ASSERT_EQ
(
y_cpu
,
x_v
*
2
+
3
);
ASSERT_EQ
(
y_cpu
[
0
],
x_v
[
0
]
*
2
+
3
);
}
TEST_F
(
TensorRTEngineTest
,
add_layer_multi_dim
)
{
...
...
@@ -83,12 +103,13 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
// instead of row-major, which is [[1.0, 1.1], [3.3, 4.4]]
float
raw_weight
[
4
]
=
{
1.0
,
1.1
,
3.3
,
4.4
};
float
raw_bias
[
2
]
=
{
1.3
,
2.4
};
std
::
vector
<
void
*>
buffers
(
2
);
// TRT binded inputs
TensorRTEngine
::
Weight
weight
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_weight
,
4
);
TensorRTEngine
::
Weight
bias
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_bias
,
2
);
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
DimsCHW
{
1
,
2
,
1
});
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
x
,
2
,
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
x
,
2
,
weight
.
get
(),
bias
.
get
());
PADDLE_ENFORCE
(
fc_layer
!=
nullptr
);
...
...
@@ -96,19 +117,27 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
float
x_v
[
2
]
=
{
1.0
,
2.0
};
engine_
->
SetInputFromCPU
(
"x"
,
reinterpret_cast
<
void
*>
(
&
x_v
),
2
*
sizeof
(
float
));
engine_
->
Execute
(
1
);
// fill in real data
std
::
vector
<
float
>
x_v
=
{
1.0
,
2.0
};
std
::
vector
<
float
>
y_cpu
;
PrepareInputOutput
(
x_v
,
{
2
});
auto
*
x_v_gpu_data
=
input_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
y_gpu_data
=
output_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
1
,
buffers
);
LOG
(
INFO
)
<<
"to get output"
;
float
y_cpu
[
2
]
=
{
-
1.
,
-
1.
}
;
GetOutput
(
&
y_cpu
)
;
auto
dims
=
engine_
->
GetITensor
(
"y"
)
->
getDimensions
();
ASSERT_EQ
(
dims
.
nbDims
,
3
);
ASSERT_EQ
(
dims
.
d
[
0
],
2
);
ASSERT_EQ
(
dims
.
d
[
1
],
1
);
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
[
0
],
2
*
sizeof
(
float
));
ASSERT_EQ
(
y_cpu
[
0
],
4.5
);
ASSERT_EQ
(
y_cpu
[
1
],
14.5
);
}
...
...
@@ -117,12 +146,13 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
// Weight in CPU memory.
float
raw_weight
[
9
]
=
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
};
float
raw_bias
[
1
]
=
{
0
};
std
::
vector
<
void
*>
buffers
(
2
);
// TRT binded inputs
TensorRTEngine
::
Weight
weight
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_weight
,
9
);
TensorRTEngine
::
Weight
bias
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_bias
,
1
);
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
Dims3
{
1
,
3
,
3
});
auto
*
conv_layer
=
auto
*
conv_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
x
,
1
,
nvinfer1
::
DimsHW
{
3
,
3
},
weight
.
get
(),
bias
.
get
());
PADDLE_ENFORCE
(
conv_layer
!=
nullptr
);
...
...
@@ -133,28 +163,37 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
float
x_v
[
18
]
=
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
};
engine_
->
SetInputFromCPU
(
"x"
,
reinterpret_cast
<
void
*>
(
&
x_v
),
18
*
sizeof
(
float
));
engine_
->
Execute
(
2
);
// fill in real data
std
::
vector
<
float
>
x_v
=
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
};
std
::
vector
<
float
>
y_cpu
;
PrepareInputOutput
(
x_v
,
{
18
});
auto
*
x_v_gpu_data
=
input_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
y_gpu_data
=
output_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
2
,
buffers
);
LOG
(
INFO
)
<<
"to get output"
;
float
*
y_cpu
=
new
float
[
18
]
;
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
[
0
],
18
*
sizeof
(
float
));
GetOutput
(
&
y_cpu
)
;
ASSERT_EQ
(
y_cpu
[
0
],
4.0
);
ASSERT_EQ
(
y_cpu
[
1
],
6.0
);
}
TEST_F
(
TensorRTEngineTest
,
test_pool2d
)
{
// Weight in CPU memory.
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
Dims3
{
1
,
2
,
2
});
std
::
vector
<
void
*>
buffers
(
2
);
// TRT binded inputs
nvinfer1
::
PoolingType
pool_t
=
nvinfer1
::
PoolingType
::
kAVERAGE
;
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
x
),
pool_t
,
nvinfer1
::
DimsHW
{
2
,
2
});
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
x
),
pool_t
,
nvinfer1
::
DimsHW
{
2
,
2
});
PADDLE_ENFORCE
(
pool_layer
!=
nullptr
);
pool_layer
->
setStride
(
nvinfer1
::
DimsHW
{
1
,
1
});
...
...
@@ -164,14 +203,21 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
float
x_v
[
8
]
=
{
1.0
,
2.0
,
5.0
,
0.0
,
2.0
,
3.0
,
5.0
,
10.0
};
engine_
->
SetInputFromCPU
(
"x"
,
reinterpret_cast
<
void
*>
(
&
x_v
),
8
*
sizeof
(
float
));
engine_
->
Execute
(
2
);
// fill in real data
std
::
vector
<
float
>
x_v
=
{
1.0
,
2.0
,
5.0
,
0.0
,
2.0
,
3.0
,
5.0
,
10.0
};
std
::
vector
<
float
>
y_cpu
;
PrepareInputOutput
(
x_v
,
{
2
});
auto
*
x_v_gpu_data
=
input_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
y_gpu_data
=
output_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
2
,
buffers
);
LOG
(
INFO
)
<<
"to get output"
;
float
*
y_cpu
=
new
float
[
2
];
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
[
0
],
2
*
sizeof
(
float
));
GetOutput
(
&
y_cpu
);
ASSERT_EQ
(
y_cpu
[
0
],
2.0
);
ASSERT_EQ
(
y_cpu
[
1
],
5.0
);
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
88c24baa
...
...
@@ -106,6 +106,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
if
(
enable_int8_
&&
calibration_data_
.
size
())
{
calibrator_
.
reset
(
new
TRTInt8Calibrator
(
calibration_data_
));
}
// we will create an engine here.
if
(
!
calibration_mode_
)
{
// trt_engine_.reset();
}
}
protected:
...
...
@@ -125,7 +130,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunCalibration
(
scope
,
dev_place
);
return
;
}
RunTrt
(
scope
,
dev_place
);
auto
trt_engine
=
GetEngine
(
scope
,
dev_place
);
RunTrt
(
scope
,
dev_place
,
trt_engine
);
}
void
RunCalibration
(
const
framework
::
Scope
&
scope
,
...
...
@@ -155,10 +161,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_res
->
calib_
.
reset
(
new
TRTInt8Calibrator
(
calib_buffers
,
runtime_batch
,
engine_key_
,
dev_place
));
calib_res
->
thr_
.
reset
(
new
std
::
thread
([
&
]()
{
calib_res
->
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
enable_int8_
,
calib_res
->
calib_
.
get
()));
calib_res
->
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
enable_int8_
,
calib_res
->
calib_
.
get
()));
VLOG
(
3
)
<<
"start the calib trt engine thread"
;
Prepare
(
scope
,
dev_place
,
calib_res
->
engine_
.
get
());
}));
...
...
@@ -180,28 +185,30 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunNativeImpl
(
scope
,
dev_place
);
}
void
RunTrt
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_plac
e
)
const
{
void
RunTrt
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
TensorRTEngine
*
engin
e
)
const
{
int
runtime_batch
=
1
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
,
enable_int8_
,
calibrator_
.
get
()));
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
}
auto
*
engine
=
trt_engine_
.
get
();
//
auto *engine = trt_engine_.get();
PADDLE_ENFORCE
(
!
input_names_
.
empty
(),
"should pass more than one inputs"
);
std
::
vector
<
std
::
string
>
output_maps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// Convert input tensor from fluid to engine.
int
num_inputs
=
0
;
for
(
const
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
num_inputs
+=
1
;
}
const
int
num_bindings
=
num_inputs
+
Outputs
(
"Ys"
).
size
();
std
::
vector
<
void
*>
buffers
(
num_bindings
);
// Bind input tensor to TRT.
for
(
const
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
// convert input and copy to TRT engine's buffer
...
...
@@ -209,26 +216,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
auto
t_shape
=
framework
::
vectorize
(
t
.
dims
());
runtime_batch
=
t_shape
[
0
];
if
(
platform
::
is_cpu_place
(
t
.
place
()))
{
engine
->
SetInputFromCPU
(
x
,
static_cast
<
const
void
*>
(
t
.
data
<
void
>
()),
t
.
memory_size
());
}
else
{
engine
->
SetInputFromGPU
(
x
,
static_cast
<
const
void
*>
(
t
.
data
<
void
>
()),
t
.
memory_size
());
}
}
cudaStreamSynchronize
(
stream
);
PADDLE_ENFORCE_LE
(
runtime_batch
,
max_batch_size_
);
// Execute the engine.
engine
->
Execute
(
runtime_batch
);
const
int
bind_index
=
engine
->
engine
()
->
getBindingIndex
(
x
.
c_str
());
PADDLE_ENFORCE
(
bind_index
<
num_bindings
,
"The bind index should be less than num_bindings"
);
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
}
//
Convert output tensor from engine to fluid
//
Bind output tensor to TRT.
int
output_index
=
0
;
VLOG
(
4
)
<<
"TensorRT Engine Op Outputs:"
;
for
(
const
auto
&
y
:
Outputs
(
"Ys"
))
{
VLOG
(
4
)
<<
y
;
// convert output and copy to fluid.
nvinfer1
::
ITensor
*
trt_t
=
engine
->
GetITensor
(
output_maps
[
output_index
]);
auto
dims
=
trt_t
->
getDimensions
();
// Use the output ITensor's dims to reshape the Fluid Tensor.
...
...
@@ -238,27 +236,46 @@ class TensorRTEngineOp : public framework::OperatorBase {
for
(
int
i
=
0
;
i
<
dims
.
nbDims
;
i
++
)
{
ddim
.
push_back
(
dims
.
d
[
i
]);
}
auto
*
fluid_v
=
scope
.
FindVar
(
y
);
PADDLE_ENFORCE_NOT_NULL
(
fluid_v
,
"no output variable called %s"
,
y
);
auto
*
fluid_t
=
fluid_v
->
GetMutable
<
framework
::
LoDTensor
>
();
fluid_t
->
Resize
(
framework
::
make_ddim
(
ddim
));
// TODO(Superjomn) change this float to dtype size.
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
)
*
runtime_batch
;
engine
->
GetOutputInGPU
(
output_maps
[
output_index
],
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
)),
size
*
sizeof
(
float
));
const
int
bind_index
=
engine
->
engine
()
->
getBindingIndex
(
output_maps
[
output_index
].
c_str
());
PADDLE_ENFORCE
(
bind_index
<
num_bindings
,
"The bind index should be less than num_bindings"
);
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
fluid_t
->
mutable_data
<
float
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
)));
output_index
+=
1
;
}
PADDLE_ENFORCE_LE
(
runtime_batch
,
max_batch_size_
);
// Execute the engine.
engine
->
Execute
(
runtime_batch
,
buffers
);
cudaStreamSynchronize
(
stream
);
}
TensorRTEngine
*
GetEngine
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
enable_int8_
,
calibrator_
.
get
()));
if
(
true
)
{
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
}
else
{
// create static engine
}
}
return
trt_engine_
.
get
();
}
void
Prepare
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
TensorRTEngine
*
engine
)
const
{
LOG
(
INFO
)
<<
"Prepare TRT engine (Optimize model structure, Select OP "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录