Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ec9eb220
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ec9eb220
编写于
8月 30, 2018
作者:
X
Xin Pan
提交者:
GitHub
8月 30, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13039 from NHZlX/release_trt_submit
Tensorrt support : mobilenet resnet50
上级
1ca2cde1
0e40429e
变更
29
隐藏空白更改
内联
并排
Showing
29 changed file
with
642 addition
and
192 deletion
+642
-192
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+3
-2
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+0
-6
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h
.../fluid/inference/analysis/data_flow_graph_to_fluid_pass.h
+0
-3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
...fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
+1
-1
paddle/fluid/inference/analysis/subgraph_splitter.cc
paddle/fluid/inference/analysis/subgraph_splitter.cc
+1
-0
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+14
-1
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
...luid/inference/api/api_tensorrt_subgraph_engine_tester.cc
+1
-0
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+8
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+6
-3
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
+136
-0
paddle/fluid/inference/tensorrt/convert/concat_op.cc
paddle/fluid/inference/tensorrt/convert/concat_op.cc
+57
-0
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+16
-6
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+14
-6
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+15
-12
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+8
-0
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+8
-1
paddle/fluid/inference/tensorrt/convert/test_batch_norm_op.cc
...le/fluid/inference/tensorrt/convert/test_batch_norm_op.cc
+71
-0
paddle/fluid/inference/tensorrt/convert/test_concat_op.cc
paddle/fluid/inference/tensorrt/convert/test_concat_op.cc
+49
-0
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
+10
-2
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+28
-8
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+9
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+23
-5
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-1
paddle/fluid/operators/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt_engine_op.cc
+2
-107
paddle/fluid/operators/tensorrt_engine_op.cu.cc
paddle/fluid/operators/tensorrt_engine_op.cu.cc
+24
-0
paddle/fluid/operators/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt_engine_op.h
+111
-8
paddle/fluid/operators/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt_engine_op_test.cc
+22
-19
未找到文件。
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
ec9eb220
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
namespace
paddle
{
namespace
paddle
{
DEFINE_bool
(
inference_analysis_enable_tensorrt_subgraph_engine
,
tru
e
,
DEFINE_bool
(
inference_analysis_enable_tensorrt_subgraph_engine
,
fals
e
,
"Enable subgraph to TensorRT engine for acceleration"
);
"Enable subgraph to TensorRT engine for acceleration"
);
DEFINE_string
(
inference_analysis_graphviz_log_root
,
"./"
,
DEFINE_string
(
inference_analysis_graphviz_log_root
,
"./"
,
...
@@ -44,7 +44,8 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -44,7 +44,8 @@ class DfgPassManagerImpl final : public DfgPassManager {
if
(
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
)
{
if
(
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
)
{
auto
trt_teller
=
[
&
](
const
Node
*
node
)
{
auto
trt_teller
=
[
&
](
const
Node
*
node
)
{
std
::
unordered_set
<
std
::
string
>
teller_set
(
std
::
unordered_set
<
std
::
string
>
teller_set
(
{
"elementwise_add"
,
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
});
{
"elementwise_add"
,
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
});
if
(
!
node
->
IsFunction
())
return
false
;
if
(
!
node
->
IsFunction
())
return
false
;
const
auto
*
func
=
static_cast
<
const
Function
*>
(
node
);
const
auto
*
func
=
static_cast
<
const
Function
*>
(
node
);
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
ec9eb220
...
@@ -23,9 +23,6 @@
...
@@ -23,9 +23,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
DEFINE_int32
(
tensorrt_max_batchsize
,
3
,
"TensorRT maximum batch size"
);
DEFINE_int32
(
tensorrt_workspace_size
,
2048
,
"TensorRT workspace size"
);
namespace
analysis
{
namespace
analysis
{
using
framework
::
proto
::
ProgramDesc
;
using
framework
::
proto
::
ProgramDesc
;
...
@@ -52,7 +49,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
...
@@ -52,7 +49,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool
DataFlowGraphToFluidPass
::
Finalize
()
{
return
true
;
}
bool
DataFlowGraphToFluidPass
::
Finalize
()
{
return
true
;
}
void
DataFlowGraphToFluidPass
::
Run
(
DataFlowGraph
*
graph
)
{
void
DataFlowGraphToFluidPass
::
Run
(
DataFlowGraph
*
graph
)
{
FilterRedundantOutputOfSubGraph
(
graph
);
LOG
(
INFO
)
<<
"graph.inputs "
<<
graph
->
inputs
.
size
();
LOG
(
INFO
)
<<
"graph.inputs "
<<
graph
->
inputs
.
size
();
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
graph
).
nodes_in_TS
())
{
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
graph
).
nodes_in_TS
())
{
if
(
node
.
deleted
())
continue
;
if
(
node
.
deleted
())
continue
;
...
@@ -191,8 +187,6 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...
@@ -191,8 +187,6 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
// Set attrs
// Set attrs
SetAttr
(
desc
.
Proto
(),
"subgraph"
,
block
->
SerializeAsString
());
SetAttr
(
desc
.
Proto
(),
"subgraph"
,
block
->
SerializeAsString
());
SetAttr
(
desc
.
Proto
(),
"engine_uniq_key"
,
"trt-"
+
std
::
to_string
(
counter
++
));
SetAttr
(
desc
.
Proto
(),
"engine_uniq_key"
,
"trt-"
+
std
::
to_string
(
counter
++
));
SetAttr
(
desc
.
Proto
(),
"max_batch"
,
FLAGS_tensorrt_max_batchsize
);
SetAttr
(
desc
.
Proto
(),
"max_workspace"
,
FLAGS_tensorrt_workspace_size
);
SetAttr
(
desc
.
Proto
(),
"parameters"
,
ExtractParameters
(
graph
.
nodes
.
nodes
()));
SetAttr
(
desc
.
Proto
(),
"parameters"
,
ExtractParameters
(
graph
.
nodes
.
nodes
()));
SetAttr
(
desc
.
Proto
(),
"output_name_mapping"
,
output_mapping
);
SetAttr
(
desc
.
Proto
(),
"output_name_mapping"
,
output_mapping
);
node
->
SetPbMsg
(
desc
.
Proto
()
->
SerializeAsString
());
node
->
SetPbMsg
(
desc
.
Proto
()
->
SerializeAsString
());
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h
浏览文件 @
ec9eb220
...
@@ -27,9 +27,6 @@
...
@@ -27,9 +27,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
DECLARE_int32
(
tensorrt_max_batchsize
);
DECLARE_int32
(
tensorrt_workspace_size
);
namespace
analysis
{
namespace
analysis
{
class
DataFlowGraphToFluidPass
final
:
public
DataFlowGraphPass
{
class
DataFlowGraphToFluidPass
final
:
public
DataFlowGraphPass
{
public:
public:
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
浏览文件 @
ec9eb220
...
@@ -92,6 +92,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
...
@@ -92,6 +92,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
auto
*
in
=
graph
->
nodes
.
GetMutable
(
var2id
.
at
(
in_var
.
arguments
(
k
)));
auto
*
in
=
graph
->
nodes
.
GetMutable
(
var2id
.
at
(
in_var
.
arguments
(
k
)));
in
->
outlinks
.
push_back
(
o
);
in
->
outlinks
.
push_back
(
o
);
o
->
inlinks
.
push_back
(
in
);
o
->
inlinks
.
push_back
(
in
);
unique_written_vars
.
insert
(
in
);
}
}
}
}
for
(
int
j
=
0
;
j
<
op
.
outputs_size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
op
.
outputs_size
();
j
++
)
{
...
@@ -112,7 +113,6 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
...
@@ -112,7 +113,6 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
}
}
out
->
inlinks
.
push_back
(
o
);
out
->
inlinks
.
push_back
(
o
);
o
->
outlinks
.
push_back
(
out
);
o
->
outlinks
.
push_back
(
out
);
unique_written_vars
.
insert
(
out
);
}
}
}
}
}
}
...
...
paddle/fluid/inference/analysis/subgraph_splitter.cc
浏览文件 @
ec9eb220
...
@@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
...
@@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
inlink_or_outlink_cleaner
(
o
->
inlinks
);
inlink_or_outlink_cleaner
(
o
->
inlinks
);
}
}
}
}
FilterRedundantOutputOfSubGraph
(
graph_
);
}
}
}
// namespace analysis
}
// namespace analysis
...
...
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
浏览文件 @
ec9eb220
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
...
@@ -32,7 +33,9 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
...
@@ -32,7 +33,9 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
VLOG
(
3
)
<<
"Predictor::init()"
;
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
=
true
;
FLAGS_tensorrt_max_batch_size
=
config_
.
max_batch_size
;
FLAGS_tensorrt_workspace_size
=
config_
.
workspace_size
;
if
(
config_
.
use_gpu
)
{
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
}
else
{
...
@@ -150,3 +153,13 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
...
@@ -150,3 +153,13 @@ CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
}
}
}
// namespace paddle
}
// namespace paddle
USE_TRT_CONVERTER
(
elementwise_add_weight
);
USE_TRT_CONVERTER
(
mul
);
USE_TRT_CONVERTER
(
conv2d
);
USE_TRT_CONVERTER
(
relu
);
USE_TRT_CONVERTER
(
fc
);
USE_TRT_CONVERTER
(
pool2d
);
USE_TRT_CONVERTER
(
softmax
);
USE_TRT_CONVERTER
(
batch_norm
);
USE_TRT_CONVERTER
(
concat
);
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
浏览文件 @
ec9eb220
...
@@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
...
@@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
config1
.
use_gpu
=
true
;
config1
.
use_gpu
=
true
;
config1
.
fraction_of_gpu_memory
=
0.3
;
config1
.
fraction_of_gpu_memory
=
0.3
;
config1
.
device
=
0
;
config1
.
device
=
0
;
config1
.
max_batch_size
=
10
;
auto
predictor0
=
auto
predictor0
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config0
);
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config0
);
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
ec9eb220
...
@@ -137,6 +137,14 @@ struct AnakinConfig : public PaddlePredictor::Config {
...
@@ -137,6 +137,14 @@ struct AnakinConfig : public PaddlePredictor::Config {
struct
TensorRTConfig
:
public
NativeConfig
{
struct
TensorRTConfig
:
public
NativeConfig
{
// Determine whether a subgraph will be executed by TRT.
// Determine whether a subgraph will be executed by TRT.
int
min_subgraph_size
{
1
};
int
min_subgraph_size
{
1
};
// While TensorRT allows an engine optimized for a given max batch size
// to run at any smaller size, the performance for those smaller
// sizes may not be as well-optimized. Therefore, Max batch is best
// equivalent to the runtime batch size.
int
max_batch_size
{
1
};
// For workspace_size, refer it from here:
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
int
workspace_size
{
1
<<
30
};
};
};
// A factory to help create different predictors.
// A factory to help create different predictors.
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
ec9eb220
# Add TRT tests
# Add TRT tests
nv_library
(
tensorrt_converter
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
activation_op.cc softmax
_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat
_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry
)
DEPS tensorrt_engine operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
@@ -18,9 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
...
@@ -18,9 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine conv_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine conv_op SERIAL
)
nv_test
(
test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
nv_test
(
test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine pool_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine pool_op SERIAL
)
nv_test
(
test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
nv_test
(
test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine elementwise_add_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine elementwise_add_op SERIAL
)
nv_test
(
test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
nv_test
(
test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine softmax_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine softmax_op SERIAL
)
nv_test
(
test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine batch_norm_op SERIAL
)
nv_test
(
test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine concat_op SERIAL
)
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
0 → 100644
浏览文件 @
ec9eb220
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
BatchNormOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
LOG
(
INFO
)
<<
"convert a fluid batch norm op to tensorrt batch_norm"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Bias"
).
size
(),
1
);
// Bias is a weight
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Mean"
).
size
(),
1
);
// Mean is a weight
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Scale"
).
size
(),
1
);
// Scale is a weight
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Variance"
).
size
(),
1
);
// Variance is a weight
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Y"
).
size
(),
1
);
auto
*
X
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
).
front
());
// Declare weights
auto
*
Bias_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Bias"
).
front
());
auto
*
Mean_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Mean"
).
front
());
auto
*
Scale_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Scale"
).
front
());
auto
*
Variance_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Variance"
).
front
());
const
float
eps
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"epsilon"
));
PADDLE_ENFORCE_NOT_NULL
(
Bias_v
);
PADDLE_ENFORCE_NOT_NULL
(
Mean_v
);
PADDLE_ENFORCE_NOT_NULL
(
Scale_v
);
PADDLE_ENFORCE_NOT_NULL
(
Variance_v
);
// get tensor
auto
*
Bias_t
=
Bias_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Mean_t
=
Mean_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Scale_t
=
Scale_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Variance_t
=
Variance_v
->
GetMutable
<
framework
::
LoDTensor
>
();
// create temp tensor for weights
framework
::
LoDTensor
bias_tensor
;
framework
::
LoDTensor
mean_tensor
;
framework
::
LoDTensor
scale_tensor
;
framework
::
LoDTensor
variance_tensor
;
bias_tensor
.
Resize
(
Bias_t
->
dims
());
mean_tensor
.
Resize
(
Mean_t
->
dims
());
scale_tensor
.
Resize
(
Scale_t
->
dims
());
variance_tensor
.
Resize
(
Variance_t
->
dims
());
platform
::
CPUPlace
cpu_place
;
// copy data from gpu to cpu
TensorCopySync
((
*
Bias_t
),
cpu_place
,
&
bias_tensor
);
TensorCopySync
((
*
Mean_t
),
cpu_place
,
&
mean_tensor
);
TensorCopySync
((
*
Scale_t
),
cpu_place
,
&
scale_tensor
);
TensorCopySync
((
*
Variance_t
),
cpu_place
,
&
variance_tensor
);
auto
*
bias_data
=
bias_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
mean_data
=
mean_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
scale_data
=
scale_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
variance_data
=
variance_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
std
::
unique_ptr
<
framework
::
LoDTensor
>
combile_scale_tensor
(
new
framework
::
LoDTensor
());
std
::
unique_ptr
<
framework
::
LoDTensor
>
combile_bias_tensor
(
new
framework
::
LoDTensor
());
combile_scale_tensor
->
Resize
(
scale_tensor
.
dims
());
combile_bias_tensor
->
Resize
(
bias_tensor
.
dims
());
auto
*
combile_scale_data
=
combile_scale_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
combile_bias_data
=
combile_bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
size_t
ele_num
=
combile_scale_tensor
->
memory_size
()
/
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
ele_num
;
i
++
)
{
float
scale
=
scale_data
[
i
];
float
bias
=
bias_data
[
i
];
float
mean
=
mean_data
[
i
];
float
variance
=
variance_data
[
i
];
combile_scale_data
[
i
]
=
scale
/
sqrtf
(
variance
+
eps
);
combile_bias_data
[
i
]
=
bias
-
mean
*
combile_scale_data
[
i
];
}
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
combile_scale_data
),
combile_scale_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
combile_bias_data
),
combile_bias_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
nvinfer1
::
IScaleLayer
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
nvinfer1
::
ScaleMode
::
kCHANNEL
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Y"
).
front
();
engine_
->
weight_map
[
op_desc
.
Input
(
"Bias"
).
front
()]
=
std
::
move
(
combile_bias_tensor
);
engine_
->
weight_map
[
op_desc
.
Input
(
"Scale"
).
front
()]
=
std
::
move
(
combile_scale_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
batch_norm
,
BatchNormOpConverter
);
paddle/fluid/inference/tensorrt/convert/concat_op.cc
0 → 100644
浏览文件 @
ec9eb220
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class
ConcatOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
;
for
(
auto
&
input_name
:
op_desc
.
Input
(
"X"
))
{
itensors
.
push_back
(
engine_
->
GetITensor
(
input_name
));
}
int
axis
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"axis"
));
PADDLE_ENFORCE
(
axis
>
0
,
"The axis attr of Concat op should be large than 0 for trt"
);
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Concatenation
,
itensors
.
data
(),
itensors
.
size
());
axis
=
axis
-
1
;
// Remove batch dim
layer
->
setAxis
(
axis
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
engine_
->
DeclareOutput
(
output_name
);
}
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
concat
,
ConcatOpConverter
);
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
ec9eb220
...
@@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter {
...
@@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter {
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Filter"
).
front
());
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Filter"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
4UL
);
platform
::
CPUPlace
cpu_place
;
const
int
n_output
=
Y_t
->
dims
()[
0
];
std
::
unique_ptr
<
framework
::
LoDTensor
>
weight_tensor
(
const
int
filter_h
=
Y_t
->
dims
()[
2
];
new
framework
::
LoDTensor
());
const
int
filter_w
=
Y_t
->
dims
()[
3
];
weight_tensor
->
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
weight_tensor
.
get
());
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
weight_tensor
->
dims
().
size
(),
4UL
);
const
int
n_output
=
weight_tensor
->
dims
()[
0
];
const
int
filter_h
=
weight_tensor
->
dims
()[
2
];
const
int
filter_w
=
weight_tensor
->
dims
()[
3
];
const
int
groups
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"groups"
));
const
int
groups
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"groups"
));
const
std
::
vector
<
int
>
dilations
=
const
std
::
vector
<
int
>
dilations
=
...
@@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter {
...
@@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter {
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
weight_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
...
@@ -70,6 +78,8 @@ class Conv2dOpConverter : public OpConverter {
...
@@ -70,6 +78,8 @@ class Conv2dOpConverter : public OpConverter {
layer
->
setNbGroups
(
groups
);
layer
->
setNbGroups
(
groups
);
auto
output_name
=
op_desc
.
Output
(
"Output"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Output"
).
front
();
engine_
->
weight_map
[
op_desc
.
Input
(
"Filter"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
engine_
->
DeclareOutput
(
output_name
);
...
...
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
浏览文件 @
ec9eb220
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -40,10 +39,17 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -40,10 +39,17 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Y"
).
front
());
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Y"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
platform
::
CPUPlace
cpu_place
;
std
::
unique_ptr
<
framework
::
LoDTensor
>
weight_tensor
(
new
framework
::
LoDTensor
());
weight_tensor
->
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
weight_tensor
.
get
());
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
scale_mode
=
nvinfer1
::
ScaleMode
::
kELEMENTWISE
;
auto
scale_mode
=
nvinfer1
::
ScaleMode
::
kELEMENTWISE
;
std
::
vector
<
int
>
dims_y
=
framework
::
vectorize2int
(
Y_t
->
dims
());
std
::
vector
<
int
>
dims_y
=
framework
::
vectorize2int
(
weight_tensor
->
dims
());
if
(
static_cast
<
int
>
(
dims_y
.
size
())
==
dims_x
.
nbDims
+
1
)
{
if
(
static_cast
<
int
>
(
dims_y
.
size
())
==
dims_x
.
nbDims
+
1
)
{
if
(
dims_y
[
0
]
==
1
)
dims_y
.
erase
(
dims_y
.
begin
());
if
(
dims_y
[
0
]
==
1
)
dims_y
.
erase
(
dims_y
.
begin
());
}
}
...
@@ -70,9 +76,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -70,9 +76,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
PADDLE_THROW
(
"TensorRT unsupported weight Shape for Elementwise op!"
);
PADDLE_THROW
(
"TensorRT unsupported weight Shape for Elementwise op!"
);
}
}
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
TensorRTEngine
::
Weight
shift_weights
{
static_cast
<
void
*>
(
weight_data
),
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
weight_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
0
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
...
@@ -82,6 +88,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -82,6 +88,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
engine_
,
Scale
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
scale_mode
,
engine_
,
Scale
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
scale_mode
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
// output, so place the declaration inside.
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
ec9eb220
...
@@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -73,19 +68,26 @@ class FcOpConverter : public OpConverter {
...
@@ -73,19 +68,26 @@ class FcOpConverter : public OpConverter {
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, that can't be avoided.
// assigned from CPU memory, that can't be avoided.
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
platform
::
CPUPlace
cpu_place
;
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
);
// a matrix
framework
::
LoDTensor
weight_tensor
;
size_t
n_output
=
Y_t
->
dims
()[
1
];
weight_tensor
.
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
&
weight_tensor
);
framework
::
LoDTensor
tmp
;
auto
*
weight_data
=
weight_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
tmp
.
Resize
(
Y_t
->
dims
());
memcpy
(
tmp
.
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
weight_data
,
PADDLE_ENFORCE_EQ
(
weight_tensor
.
dims
().
size
(),
2UL
);
// a matrix
size_t
n_output
=
weight_tensor
.
dims
()[
1
];
std
::
unique_ptr
<
framework
::
Tensor
>
tmp
(
new
framework
::
LoDTensor
());
tmp
->
Resize
(
weight_tensor
.
dims
());
memcpy
(
tmp
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
weight_data
,
Y_t
->
dims
()[
0
]
*
Y_t
->
dims
()[
1
]
*
sizeof
(
float
));
Y_t
->
dims
()[
0
]
*
Y_t
->
dims
()[
1
]
*
sizeof
(
float
));
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
Y_t
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
tmp_weight
(
nvinfer1
::
DataType
::
kFLOAT
,
TensorRTEngine
::
Weight
tmp_weight
(
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
tmp
.
data
<
float
>
()),
static_cast
<
void
*>
(
tmp
->
data
<
float
>
()),
Y_t
->
memory_size
()
/
sizeof
(
float
));
Y_t
->
memory_size
()
/
sizeof
(
float
));
weight
.
dims
.
assign
({
Y_t
->
dims
()[
0
],
Y_t
->
dims
()[
1
]});
weight
.
dims
.
assign
({
Y_t
->
dims
()[
0
],
Y_t
->
dims
()[
1
]});
tmp_weight
.
dims
=
weight
.
dims
;
tmp_weight
.
dims
=
weight
.
dims
;
...
@@ -106,6 +108,7 @@ class FcOpConverter : public OpConverter {
...
@@ -106,6 +108,7 @@ class FcOpConverter : public OpConverter {
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
tmp
);
if
(
test_mode
)
{
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
engine_
->
DeclareOutput
(
output_name
);
}
}
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
ec9eb220
...
@@ -79,6 +79,14 @@ class OpConverter {
...
@@ -79,6 +79,14 @@ class OpConverter {
it
=
it
=
Registry
<
OpConverter
>::
Lookup
(
"elementwise_"
+
op_type
+
"_tensor"
);
Registry
<
OpConverter
>::
Lookup
(
"elementwise_"
+
op_type
+
"_tensor"
);
}
}
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
op_desc
.
Type
());
}
if
(
op_desc
.
Type
()
==
"depthwise_conv2d"
)
{
it
=
Registry
<
OpConverter
>::
Lookup
(
"conv2d"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
op_desc
.
Type
());
}
}
if
(
!
it
)
{
if
(
!
it
)
{
...
...
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
浏览文件 @
ec9eb220
...
@@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
bool
global_pooling
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"global_pooling"
));
std
::
string
pool_type
=
std
::
string
pool_type
=
boost
::
get
<
std
::
string
>
(
op_desc
.
GetAttr
(
"pooling_type"
));
boost
::
get
<
std
::
string
>
(
op_desc
.
GetAttr
(
"pooling_type"
));
std
::
vector
<
int
>
ksize
=
std
::
vector
<
int
>
ksize
=
...
@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
std
::
vector
<
int
>
paddings
=
std
::
vector
<
int
>
paddings
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
const
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
if
(
global_pooling
==
true
)
{
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
nbDims
=
input_shape
.
nbDims
;
nv_ksize
.
d
[
0
]
=
input_shape
.
d
[
nbDims
-
2
];
nv_ksize
.
d
[
1
]
=
input_shape
.
d
[
nbDims
-
1
];
}
const
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
const
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
const
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
const
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
...
...
paddle/fluid/inference/tensorrt/convert/test_batch_norm_op.cc
0 → 100644
浏览文件 @
ec9eb220
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
TEST
(
batch_norm_op
,
test
)
{
std
::
unordered_set
<
std
::
string
>
parameters
(
{
"batch_norm_scale"
,
"batch_norm_bias"
,
"batch_norm_mean"
,
"batch_norm_variance"
});
framework
::
Scope
scope
;
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
std
::
vector
<
int
>
param_shape
{
2
};
validator
.
DeclInputVar
(
"batch_norm_X"
,
nvinfer1
::
DimsCHW
(
2
,
5
,
5
));
validator
.
DeclParamVar
(
"batch_norm_scale"
,
param_shape
);
validator
.
DeclParamVar
(
"batch_norm_bias"
,
param_shape
);
validator
.
DeclParamVar
(
"batch_norm_mean"
,
param_shape
);
validator
.
DeclParamVar
(
"batch_norm_variance"
,
param_shape
);
validator
.
DeclOutputVar
(
"batch_norm_Y"
,
nvinfer1
::
DimsCHW
(
2
,
5
,
5
));
validator
.
DeclOutputVar
(
"batch_norm_save_mean"
,
param_shape
);
validator
.
DeclOutputVar
(
"batch_norm_save_variance"
,
param_shape
);
// Prepare Op description
framework
::
OpDesc
desc
;
desc
.
SetType
(
"batch_norm"
);
desc
.
SetInput
(
"X"
,
{
"batch_norm_X"
});
desc
.
SetInput
(
"Scale"
,
{
"batch_norm_scale"
});
desc
.
SetInput
(
"Bias"
,
{
"batch_norm_bias"
});
desc
.
SetInput
(
"Mean"
,
{
"batch_norm_mean"
});
desc
.
SetInput
(
"Variance"
,
{
"batch_norm_variance"
});
desc
.
SetOutput
(
"Y"
,
{
"batch_norm_Y"
});
desc
.
SetOutput
(
"MeanOut"
,
{
"batch_norm_mean"
});
desc
.
SetOutput
(
"VarianceOut"
,
{
"batch_norm_variance"
});
desc
.
SetOutput
(
"SavedMean"
,
{
"batch_norm_save_mean"
});
desc
.
SetOutput
(
"SavedVariance"
,
{
"batch_norm_save_variance"
});
float
eps
=
1e-5
f
;
bool
is_test
=
true
;
desc
.
SetAttr
(
"epsilon"
,
eps
);
desc
.
SetAttr
(
"is_test"
,
is_test
);
validator
.
SetOp
(
*
desc
.
Proto
());
std
::
unordered_set
<
std
::
string
>
neglected_output
=
{
"batch_norm_save_mean"
,
"batch_norm_save_variance"
,
"batch_norm_mean"
,
"batch_norm_variance"
};
validator
.
Execute
(
3
,
neglected_output
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
batch_norm
);
paddle/fluid/inference/tensorrt/convert/test_concat_op.cc
0 → 100644
浏览文件 @
ec9eb220
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
TEST
(
concat_op
,
test
)
{
std
::
unordered_set
<
std
::
string
>
parameters
({
""
});
framework
::
Scope
scope
;
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
validator
.
DeclInputVar
(
"concat_x1"
,
nvinfer1
::
DimsCHW
(
10
,
3
,
1
));
validator
.
DeclInputVar
(
"concat_x2"
,
nvinfer1
::
DimsCHW
(
3
,
3
,
1
));
validator
.
DeclInputVar
(
"concat_x3"
,
nvinfer1
::
DimsCHW
(
7
,
3
,
1
));
validator
.
DeclOutputVar
(
"concat_out"
,
nvinfer1
::
DimsCHW
(
20
,
3
,
1
));
// Prepare Op description
framework
::
OpDesc
desc
;
desc
.
SetType
(
"concat"
);
desc
.
SetInput
(
"X"
,
{
"concat_x1"
,
"concat_x2"
,
"concat_x3"
});
desc
.
SetOutput
(
"Out"
,
{
"concat_out"
});
int
axis
=
1
;
desc
.
SetAttr
(
"axis"
,
axis
);
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
Execute
(
5
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
concat
);
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
浏览文件 @
ec9eb220
...
@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) {
...
@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) {
auto
*
x
=
scope
.
Var
(
"conv2d-Y"
);
auto
*
x
=
scope
.
Var
(
"conv2d-Y"
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
x_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
0
));
OpConverter
converter
;
OpConverter
converter
;
converter
.
ConvertBlock
(
*
block
->
Proto
(),
{
"conv2d-Y"
},
scope
,
converter
.
ConvertBlock
(
*
block
->
Proto
(),
{
"conv2d-Y"
},
scope
,
...
...
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
浏览文件 @
ec9eb220
...
@@ -20,7 +20,7 @@ namespace paddle {
...
@@ -20,7 +20,7 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
TEST
(
Pool2dOpConverter
,
main
)
{
void
test_pool2d
(
bool
global_pooling
)
{
framework
::
Scope
scope
;
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
...
@@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) {
...
@@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) {
// The ITensor's Dims should not contain the batch size.
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
// So, the ITensor's Dims of input and output should be C * H * W.
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
4
,
4
));
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
4
,
4
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
2
,
2
));
if
(
global_pooling
)
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
1
,
1
));
else
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
2
,
2
));
// Prepare Op description
// Prepare Op description
framework
::
OpDesc
desc
;
framework
::
OpDesc
desc
;
...
@@ -45,6 +48,7 @@ TEST(Pool2dOpConverter, main) {
...
@@ -45,6 +48,7 @@ TEST(Pool2dOpConverter, main) {
desc
.
SetAttr
(
"ksize"
,
ksize
);
desc
.
SetAttr
(
"ksize"
,
ksize
);
desc
.
SetAttr
(
"strides"
,
strides
);
desc
.
SetAttr
(
"strides"
,
strides
);
desc
.
SetAttr
(
"paddings"
,
paddings
);
desc
.
SetAttr
(
"paddings"
,
paddings
);
desc
.
SetAttr
(
"global_pooling"
,
global_pooling
);
LOG
(
INFO
)
<<
"set OP"
;
LOG
(
INFO
)
<<
"set OP"
;
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
SetOp
(
*
desc
.
Proto
());
...
@@ -53,6 +57,10 @@ TEST(Pool2dOpConverter, main) {
...
@@ -53,6 +57,10 @@ TEST(Pool2dOpConverter, main) {
validator
.
Execute
(
3
);
validator
.
Execute
(
3
);
}
}
TEST
(
Pool2dOpConverter
,
normal
)
{
test_pool2d
(
false
);
}
TEST
(
Pool2dOpConverter
,
test_global_pooling
)
{
test_pool2d
(
true
);
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
ec9eb220
...
@@ -24,6 +24,7 @@ limitations under the License. */
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
...
@@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
...
@@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
auto
dims
=
tensor
->
dims
();
auto
dims
=
tensor
->
dims
();
size_t
num_elements
=
analysis
::
AccuDims
(
dims
,
dims
.
size
());
size_t
num_elements
=
analysis
::
AccuDims
(
dims
,
dims
.
size
());
PADDLE_ENFORCE_GT
(
num_elements
,
0
);
PADDLE_ENFORCE_GT
(
num_elements
,
0
);
auto
*
data
=
tensor
->
mutable_data
<
float
>
(
place
);
platform
::
CPUPlace
cpu_place
;
framework
::
LoDTensor
temp_tensor
;
temp_tensor
.
Resize
(
dims
);
auto
*
temp_data
=
temp_tensor
.
mutable_data
<
float
>
(
cpu_place
);
for
(
size_t
i
=
0
;
i
<
num_elements
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
num_elements
;
i
++
)
{
*
(
data
+
i
)
=
random
(
0.
,
1.
);
*
(
temp_
data
+
i
)
=
random
(
0.
,
1.
);
}
}
TensorCopySync
(
temp_tensor
,
place
,
tensor
);
}
}
/*
/*
...
@@ -91,18 +98,26 @@ class TRTConvertValidation {
...
@@ -91,18 +98,26 @@ class TRTConvertValidation {
engine_
->
DeclareInput
(
name
,
nvinfer1
::
DataType
::
kFLOAT
,
dims
);
engine_
->
DeclareInput
(
name
,
nvinfer1
::
DataType
::
kFLOAT
,
dims
);
}
}
void
DeclParamVar
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
dim_vec
)
{
DeclVar
(
name
,
dim_vec
);
}
// Declare a parameter varaible in the scope.
// Declare a parameter varaible in the scope.
void
DeclParamVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclParamVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
DeclVar
(
name
,
dims
,
true
);
DeclVar
(
name
,
dims
,
true
);
}
}
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
dim_vec
)
{
DeclVar
(
name
,
dim_vec
);
}
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
DeclVar
(
name
,
dims
);
DeclVar
(
name
,
dims
);
}
}
void
DeclVar
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
dim_vec
)
{
void
DeclVar
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
dim_vec
)
{
platform
::
C
PU
Place
place
;
platform
::
C
UDA
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
DeviceContext
ctx
(
place
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
...
@@ -141,18 +156,22 @@ class TRTConvertValidation {
...
@@ -141,18 +156,22 @@ class TRTConvertValidation {
PADDLE_ENFORCE
(
var
);
PADDLE_ENFORCE
(
var
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
engine_
->
SetInputFrom
C
PU
(
engine_
->
SetInputFrom
G
PU
(
input
,
static_cast
<
void
*>
(
tensor
->
data
<
void
>
()),
input
,
static_cast
<
void
*>
(
tensor
->
data
<
void
>
()),
sizeof
(
float
)
*
sizeof
(
float
)
*
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
}
}
}
}
void
Execute
(
int
batch_size
)
{
// We use the set 'neglected_output' here, because some Ops like batch norm,
// the outputs specified in the op des are only used during training,
// so we should neglect those output during inference.
void
Execute
(
int
batch_size
,
std
::
unordered_set
<
std
::
string
>
neglected_output
=
{})
{
// Execute Fluid Op
// Execute Fluid Op
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_size_
);
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_size_
);
platform
::
C
PU
Place
place
;
platform
::
C
UDA
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
DeviceContext
ctx
(
place
);
op_
->
Run
(
scope_
,
place
);
op_
->
Run
(
scope_
,
place
);
// Execute TRT.
// Execute TRT.
engine_
->
Execute
(
batch_size
);
engine_
->
Execute
(
batch_size
);
...
@@ -161,6 +180,7 @@ class TRTConvertValidation {
...
@@ -161,6 +180,7 @@ class TRTConvertValidation {
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
const
size_t
output_space_size
=
3000
;
const
size_t
output_space_size
=
3000
;
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
if
(
neglected_output
.
count
(
output
))
continue
;
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
trt_out
(
output_space_size
);
std
::
vector
<
float
>
trt_out
(
output_space_size
);
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
output_space_size
);
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
output_space_size
);
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
ec9eb220
...
@@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
...
@@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
}
}
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
freshDeviceId
();
batch_size_
=
batch_size
;
batch_size_
=
batch_size
;
std
::
vector
<
void
*>
buffers
;
std
::
vector
<
void
*>
buffers
;
for
(
auto
&
buf
:
buffers_
)
{
for
(
auto
&
buf
:
buffers_
)
{
...
@@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() {
...
@@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() {
}
}
void
TensorRTEngine
::
FreezeNetwork
()
{
void
TensorRTEngine
::
FreezeNetwork
()
{
freshDeviceId
();
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
"Call InitNetwork first to initialize network."
);
"Call InitNetwork first to initialize network."
);
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
...
@@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
...
@@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
void
TensorRTEngine
::
freshDeviceId
()
{
int
count
;
cudaGetDeviceCount
(
&
count
);
PADDLE_ENFORCE_LT
(
device_
,
count
);
cudaSetDevice
(
device_
);
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
ec9eb220
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
...
@@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase {
...
@@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase {
};
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
=
nullptr
,
cudaStream_t
*
stream
=
nullptr
,
int
device
=
0
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
max_workspace_
(
max_workspace
),
stream_
(
stream
?
stream
:
&
default_stream_
),
stream_
(
stream
?
stream
:
&
default_stream_
),
logger_
(
logger
)
{
logger_
(
logger
),
cudaStreamCreate
(
&
default_stream_
);
device_
(
device
)
{
freshDeviceId
();
cudaStreamCreate
(
stream_
);
}
}
virtual
~
TensorRTEngine
();
virtual
~
TensorRTEngine
();
...
@@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase {
...
@@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
void
SetRuntimeBatch
(
size_t
batch_size
);
void
SetRuntimeBatch
(
size_t
batch_size
);
int
GetRuntimeBatch
();
int
GetRuntimeBatch
();
int
GetDevice
()
{
return
device_
;
}
// A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage.
// so we need to copy the weights from GPU to CPU in our op converter.
// We use a map to store these weights for the weight memory is not released
// in advance, which affecting the construction of TRT Op.
std
::
unordered_map
<
std
::
string
/*name*/
,
std
::
unique_ptr
<
framework
::
Tensor
>>
weight_map
;
private:
private:
// the max batch size
// the max batch size
...
@@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase {
...
@@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
itensor_map_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
// TensorRT related internal members
// TensorRT related internal members
template
<
typename
T
>
template
<
typename
T
>
...
@@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase {
...
@@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase {
infer_ptr
<
nvinfer1
::
INetworkDefinition
>
infer_network_
;
infer_ptr
<
nvinfer1
::
INetworkDefinition
>
infer_network_
;
infer_ptr
<
nvinfer1
::
ICudaEngine
>
infer_engine_
;
infer_ptr
<
nvinfer1
::
ICudaEngine
>
infer_engine_
;
infer_ptr
<
nvinfer1
::
IExecutionContext
>
infer_context_
;
infer_ptr
<
nvinfer1
::
IExecutionContext
>
infer_context_
;
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void
freshDeviceId
();
};
// class TensorRTEngine
};
// class TensorRTEngine
// Add an layer__ into engine__ with args ARGS.
// Add an layer__ into engine__ with args ARGS.
...
@@ -188,8 +206,8 @@ class TRT_EngineManager {
...
@@ -188,8 +206,8 @@ class TRT_EngineManager {
// Create or get an engine called `name`
// Create or get an engine called `name`
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
,
int
gpu_device
=
0
)
{
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
);
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
,
gpu_device
);
engines_
[
name
].
reset
(
p
);
engines_
[
name
].
reset
(
p
);
return
p
;
return
p
;
}
}
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
ec9eb220
...
@@ -27,7 +27,7 @@ namespace tensorrt {
...
@@ -27,7 +27,7 @@ namespace tensorrt {
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
protected:
protected:
void
SetUp
()
override
{
void
SetUp
()
override
{
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream_
));
//
ASSERT_EQ(0, cudaStreamCreate(&stream_));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
);
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
);
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
}
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
ec9eb220
...
@@ -100,7 +100,8 @@ function(op_library TARGET)
...
@@ -100,7 +100,8 @@ function(op_library TARGET)
endif
()
endif
()
# Define operators that don't need pybind here.
# Define operators that don't need pybind here.
foreach
(
manual_pybind_op
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
)
foreach
(
manual_pybind_op
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
"tensorrt_engine_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
set
(
pybind_flag 1
)
set
(
pybind_flag 1
)
endif
()
endif
()
...
@@ -248,6 +249,7 @@ op_library(softmax_op DEPS softmax)
...
@@ -248,6 +249,7 @@ op_library(softmax_op DEPS softmax)
op_library
(
sequence_softmax_op DEPS softmax
)
op_library
(
sequence_softmax_op DEPS softmax
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(tensorrt_engine);
\n
"
)
nv_test
(
test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
nv_test
(
test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op
DEPS tensorrt_engine_op
analysis
)
analysis
)
...
...
paddle/fluid/operators/tensorrt_engine_op.cc
浏览文件 @
ec9eb220
...
@@ -17,112 +17,16 @@
...
@@ -17,112 +17,16 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace
paddle
{
namespace
paddle
{
DEFINE_int32
(
tensorrt_engine_batch_size
,
1
,
"the batch_size of TensorRT"
);
DEFINE_int32
(
tensorrt_engine_batch_size
,
1
,
"the batch_size of TensorRT"
);
DEFINE_int32
(
tensorrt_max_batch_size
,
1
,
"TensorRT maximum batch size"
);
DEFINE_int32
(
tensorrt_workspace_size
,
16
<<
20
,
"TensorRT workspace size"
);
namespace
operators
{
namespace
operators
{
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
using
FluidDT
=
framework
::
proto
::
VarType_Type
;
using
TRT_DT
=
nvinfer1
::
DataType
;
namespace
{
TRT_DT
FluidDataType2TRT
(
FluidDT
type
)
{
switch
(
type
)
{
case
FluidDT
::
VarType_Type_FP32
:
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
default:
return
TRT_DT
::
kINT32
;
}
PADDLE_THROW
(
"unkown type"
);
return
TRT_DT
::
kINT32
;
}
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
int64_t
>
&
shape
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1UL
,
"TensorRT' tensor input requires at least 2 dimensions"
);
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
"TensorRT' tensor input requires at most 4 dimensions"
);
PADDLE_ENFORCE_EQ
(
shape
.
size
(),
4UL
);
return
nvinfer1
::
DimsCHW
(
shape
[
1
],
shape
[
2
],
shape
[
3
]);
}
}
// namespace
template
<
typename
DeviceContext
,
typename
T
>
void
TensorRTEngineKernel
<
DeviceContext
,
T
>::
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
// Get the ProgramDesc and pass to convert.
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
context
.
Attr
<
std
::
string
>
(
"subgraph"
));
int
max_batch
=
context
.
Attr
<
int
>
(
"max_batch"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
auto
params
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
std
::
unordered_set
<
std
::
string
>
parameters
;
for
(
const
auto
&
param
:
params
)
{
parameters
.
insert
(
param
);
}
std
::
vector
<
std
::
string
>
output_maps
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// TODO(Superjomn) replace this with a different stream
auto
*
engine
=
Singleton
<
TRT_EngineManager
>::
Global
().
Create
(
max_batch
,
max_workspace
,
nullptr
/*engine hold its own stream*/
,
context
.
Attr
<
std
::
string
>
(
"engine_uniq_key"
));
engine
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
context
.
Inputs
(
"Xs"
))
{
if
(
parameters
.
count
(
input
))
continue
;
VLOG
(
4
)
<<
"declare input "
<<
input
;
auto
*
var
=
block
.
FindVar
(
input
);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
auto
shape
=
var
->
GetShape
();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if
(
shape
[
0
]
==
-
1
)
{
shape
[
0
]
=
FLAGS_tensorrt_engine_batch_size
;
}
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
shape
));
}
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
block_desc
,
parameters
,
context
.
scope
(),
engine
);
// Add outputs
for
(
auto
&
output
:
output_maps
)
{
engine
->
DeclareOutput
(
output
);
}
engine
->
FreezeNetwork
();
}
class
TensorRTEngineOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
TensorRTEngineOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
...
@@ -130,8 +34,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -130,8 +34,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph."
);
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph."
);
AddAttr
<
std
::
string
>
(
"engine_uniq_key"
,
"unique key for the TRT engine."
);
AddAttr
<
std
::
string
>
(
"engine_uniq_key"
,
"unique key for the TRT engine."
);
AddAttr
<
int
>
(
"max_batch"
,
"the maximum batch size."
);
AddAttr
<
int
>
(
"max_workspace"
,
"the maximum batch size."
);
AddComment
(
"TensorRT engine operator."
);
AddComment
(
"TensorRT engine operator."
);
}
}
};
};
...
@@ -150,11 +52,4 @@ namespace ops = paddle::operators;
...
@@ -150,11 +52,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
tensorrt_engine
,
ops
::
TensorRTEngineOp
,
REGISTER_OPERATOR
(
tensorrt_engine
,
ops
::
TensorRTEngineOp
,
ops
::
TensorRTEngineOpMaker
,
ops
::
TensorRTEngineOpMaker
);
ops
::
TensorRTEngineOpMaker
,
ops
::
TensorRTEngineOpMaker
);
REGISTER_OP_CPU_KERNEL
(
tensorrt_engine
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
#endif // PADDLE_WITH_CUDA
#endif // PADDLE_WITH_CUDA
paddle/fluid/operators/tensorrt_engine_op.cu.cc
0 → 100644
浏览文件 @
ec9eb220
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
tensorrt_engine
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/tensorrt_engine_op.h
浏览文件 @
ec9eb220
...
@@ -19,16 +19,51 @@
...
@@ -19,16 +19,51 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace
paddle
{
namespace
paddle
{
DECLARE_int32
(
tensorrt_engine_batch_size
);
DECLARE_int32
(
tensorrt_engine_batch_size
);
DECLARE_int32
(
tensorrt_max_batch_size
);
DECLARE_int32
(
tensorrt_workspace_size
);
namespace
operators
{
namespace
operators
{
using
FluidDT
=
framework
::
proto
::
VarType_Type
;
using
TRT_DT
=
nvinfer1
::
DataType
;
namespace
{
TRT_DT
FluidDataType2TRT
(
FluidDT
type
)
{
switch
(
type
)
{
case
FluidDT
::
VarType_Type_FP32
:
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
default:
return
TRT_DT
::
kINT32
;
}
PADDLE_THROW
(
"unkown type"
);
return
TRT_DT
::
kINT32
;
}
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
int64_t
>&
shape
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1UL
,
"TensorRT' tensor input requires at least 2 dimensions"
);
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
"TensorRT' tensor input requires at most 4 dimensions"
);
PADDLE_ENFORCE
(
shape
.
size
()
==
4UL
||
shape
.
size
()
==
2UL
);
if
(
shape
.
size
()
==
4UL
)
return
nvinfer1
::
DimsCHW
(
shape
[
1
],
shape
[
2
],
shape
[
3
]);
return
nvinfer1
::
DimsCHW
(
shape
[
1
],
1
,
1
);
}
}
// namespace
using
inference
::
Singleton
;
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
...
@@ -47,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
...
@@ -47,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
.
FindVar
(
input0
)
.
FindVar
(
input0
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
()),
->
type
()),
platform
::
CPU
Place
());
ctx
.
Get
Place
());
return
kt
;
return
kt
;
}
}
};
};
...
@@ -64,7 +99,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -64,7 +99,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto
input_names
=
context
.
op
().
Inputs
(
"Xs"
);
auto
input_names
=
context
.
op
().
Inputs
(
"Xs"
);
PADDLE_ENFORCE
(
!
input_names
.
empty
(),
"should pass more than one inputs"
);
PADDLE_ENFORCE
(
!
input_names
.
empty
(),
"should pass more than one inputs"
);
PADDLE_ENFORCE_LE
(
FLAGS_tensorrt_engine_batch_size
,
PADDLE_ENFORCE_LE
(
FLAGS_tensorrt_engine_batch_size
,
context
.
Attr
<
int
>
(
"max_batch"
)
);
FLAGS_tensorrt_max_batch_size
);
std
::
vector
<
std
::
string
>
output_maps
=
std
::
vector
<
std
::
string
>
output_maps
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
...
@@ -94,12 +129,19 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -94,12 +129,19 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// Convert output tensor from engine to fluid
// Convert output tensor from engine to fluid
int
output_index
=
0
;
int
output_index
=
0
;
VLOG
(
4
)
<<
"TensorRT Engine Op Outputs:"
;
for
(
const
auto
&
y
:
context
.
Outputs
(
"Ys"
))
{
for
(
const
auto
&
y
:
context
.
Outputs
(
"Ys"
))
{
VLOG
(
4
)
<<
y
;
// convert output and copy to fluid.
// convert output and copy to fluid.
nvinfer1
::
ITensor
*
trt_t
=
engine
->
GetITensor
(
output_maps
[
output_index
]);
nvinfer1
::
ITensor
*
trt_t
=
engine
->
GetITensor
(
output_maps
[
output_index
]);
auto
dims
=
trt_t
->
getDimensions
();
auto
dims
=
trt_t
->
getDimensions
();
// Use the output ITensor's dims to reshape the Fluid Tensor.
// Use the output ITensor's dims to reshape the Fluid Tensor.
std
::
vector
<
int
>
ddim
(
dims
.
d
,
dims
.
d
+
dims
.
nbDims
);
// The ITensor doesn't contain the batch size dim.
std
::
vector
<
int
>
ddim
;
ddim
.
push_back
(
FLAGS_tensorrt_engine_batch_size
);
for
(
int
i
=
0
;
i
<
dims
.
nbDims
;
i
++
)
{
ddim
.
push_back
(
dims
.
d
[
i
]);
}
auto
*
fluid_v
=
context
.
scope
().
FindVar
(
y
);
auto
*
fluid_v
=
context
.
scope
().
FindVar
(
y
);
PADDLE_ENFORCE_NOT_NULL
(
fluid_v
,
"no output variable called %s"
,
y
);
PADDLE_ENFORCE_NOT_NULL
(
fluid_v
,
"no output variable called %s"
,
y
);
...
@@ -113,9 +155,11 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -113,9 +155,11 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// TODO(Superjomn) change this float to dtype size.
// TODO(Superjomn) change this float to dtype size.
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
)
*
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
)
*
FLAGS_tensorrt_engine_batch_size
;
FLAGS_tensorrt_engine_batch_size
;
engine
->
GetOutputInCPU
(
output_maps
[
output_index
],
engine
->
GetOutputInGPU
(
fluid_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
output_maps
[
output_index
],
size
*
sizeof
(
float
));
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
)),
size
*
sizeof
(
float
));
//} else {
//} else {
// engine->GetOutputInGPU(
// engine->GetOutputInGPU(
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
...
@@ -128,8 +172,67 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -128,8 +172,67 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
}
}
protected:
protected:
// Build the engine.
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
{
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
;
VLOG
(
4
)
<<
"Prepare engine"
;
// Get the ProgramDesc and pass to convert.
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
context
.
Attr
<
std
::
string
>
(
"subgraph"
));
int
max_batch
=
FLAGS_tensorrt_max_batch_size
;
auto
max_workspace
=
FLAGS_tensorrt_workspace_size
;
auto
params
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
std
::
unordered_set
<
std
::
string
>
parameters
;
for
(
const
auto
&
param
:
params
)
{
parameters
.
insert
(
param
);
}
std
::
vector
<
std
::
string
>
output_maps
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// TODO(Superjomn) replace this with a different stream
auto
*
engine
=
Singleton
<
TRT_EngineManager
>::
Global
().
Create
(
max_batch
,
max_workspace
,
nullptr
/*engine hold its own stream*/
,
context
.
Attr
<
std
::
string
>
(
"engine_uniq_key"
),
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
);
engine
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
context
.
Inputs
(
"Xs"
))
{
if
(
parameters
.
count
(
input
))
continue
;
VLOG
(
4
)
<<
"declare input "
<<
input
;
auto
*
var
=
block
.
FindVar
(
input
);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
auto
shape
=
var
->
GetShape
();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if
(
shape
[
0
]
==
-
1
)
{
shape
[
0
]
=
FLAGS_tensorrt_engine_batch_size
;
}
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
shape
));
}
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
()
.
ConvertBlock
(
block_desc
,
parameters
,
context
.
scope
(),
engine
);
// Add outputs
for
(
auto
&
output
:
output_maps
)
{
engine
->
DeclareOutput
(
output
);
}
engine
->
FreezeNetwork
();
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/tensorrt_engine_op_test.cc
浏览文件 @
ec9eb220
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -23,20 +24,20 @@ limitations under the License. */
...
@@ -23,20 +24,20 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_C
PU
_ONLY_OP
(
tensorrt_engine
);
USE_C
UDA
_ONLY_OP
(
tensorrt_engine
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
{
namespace
{
void
CreateC
PU
Tensor
(
framework
::
Scope
*
scope
,
const
std
::
string
&
name
,
void
CreateC
UDA
Tensor
(
framework
::
Scope
*
scope
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
shape
)
{
const
std
::
vector
<
int64_t
>&
shape
)
{
auto
*
var
=
scope
->
Var
(
name
);
auto
*
var
=
scope
->
Var
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
dims
=
framework
::
make_ddim
(
shape
);
auto
dims
=
framework
::
make_ddim
(
shape
);
tensor
->
Resize
(
dims
);
tensor
->
Resize
(
dims
);
platform
::
C
PU
Place
place
;
platform
::
C
UDA
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
DeviceContext
ctx
(
place
);
inference
::
tensorrt
::
RandomizeTensor
(
tensor
,
place
,
ctx
);
inference
::
tensorrt
::
RandomizeTensor
(
tensor
,
place
,
ctx
);
}
}
...
@@ -57,6 +58,8 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
...
@@ -57,6 +58,8 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
using
inference
::
analysis
::
SetAttr
;
using
inference
::
analysis
::
SetAttr
;
TEST
(
TensorRTEngineOp
,
manual
)
{
TEST
(
TensorRTEngineOp
,
manual
)
{
FLAGS_tensorrt_engine_batch_size
=
2
;
FLAGS_tensorrt_max_batch_size
=
2
;
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
block_
->
set_idx
(
0
);
block_
->
set_idx
(
0
);
...
@@ -98,8 +101,6 @@ TEST(TensorRTEngineOp, manual) {
...
@@ -98,8 +101,6 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
engine_op_desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch"
,
100
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_workspace"
,
1
<<
10
);
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"engine_uniq_key"
,
"a_engine"
);
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"engine_uniq_key"
,
"a_engine"
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
std
::
vector
<
std
::
string
>
({}));
std
::
vector
<
std
::
string
>
({}));
...
@@ -112,15 +113,15 @@ TEST(TensorRTEngineOp, manual) {
...
@@ -112,15 +113,15 @@ TEST(TensorRTEngineOp, manual) {
LOG
(
INFO
)
<<
"engine_op "
<<
engine_op
.
get
();
LOG
(
INFO
)
<<
"engine_op "
<<
engine_op
.
get
();
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
C
PU
Place
place
;
platform
::
C
UDA
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
DeviceContext
ctx
(
place
);
// Prepare variables.
// Prepare variables.
CreateC
PU
Tensor
(
&
scope
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
CreateC
UDA
Tensor
(
&
scope
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
CreateC
PU
Tensor
(
&
scope
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
CreateC
UDA
Tensor
(
&
scope
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
CreateC
PU
Tensor
(
&
scope
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
CreateC
UDA
Tensor
(
&
scope
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
CreateC
PU
Tensor
(
&
scope
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
CreateC
UDA
Tensor
(
&
scope
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
CreateC
PU
Tensor
(
&
scope
,
"z0"
,
std
::
vector
<
int64_t
>
({
2
,
8
}));
CreateC
UDA
Tensor
(
&
scope
,
"z0"
,
std
::
vector
<
int64_t
>
({
2
,
8
}));
// Execute them.
// Execute them.
LOG
(
INFO
)
<<
"engine_op run"
;
LOG
(
INFO
)
<<
"engine_op run"
;
...
@@ -128,10 +129,12 @@ TEST(TensorRTEngineOp, manual) {
...
@@ -128,10 +129,12 @@ TEST(TensorRTEngineOp, manual) {
}
}
void
Execute
(
int
batch_size
,
int
input_dim
,
int
output_dim
,
int
nlayers
=
1
)
{
void
Execute
(
int
batch_size
,
int
input_dim
,
int
output_dim
,
int
nlayers
=
1
)
{
FLAGS_tensorrt_engine_batch_size
=
batch_size
;
FLAGS_tensorrt_max_batch_size
=
batch_size
;
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
C
PU
Place
place
;
platform
::
C
UDA
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
DeviceContext
ctx
(
place
);
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
block_
->
set_idx
(
0
);
block_
->
set_idx
(
0
);
...
@@ -165,10 +168,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
...
@@ -165,10 +168,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
// Prepare variables.
// Prepare variables.
if
(
!
x_created
)
{
if
(
!
x_created
)
{
CreateC
PU
Tensor
(
&
scope
,
x_name
,
std
::
vector
<
int64_t
>
(
x_shape
));
CreateC
UDA
Tensor
(
&
scope
,
x_name
,
std
::
vector
<
int64_t
>
(
x_shape
));
}
}
CreateC
PU
Tensor
(
&
scope
,
y_name
,
std
::
vector
<
int64_t
>
(
y_shape
));
CreateC
UDA
Tensor
(
&
scope
,
y_name
,
std
::
vector
<
int64_t
>
(
y_shape
));
CreateC
PU
Tensor
(
&
scope
,
z_name
,
std
::
vector
<
int64_t
>
(
z_shape
));
CreateC
UDA
Tensor
(
&
scope
,
z_name
,
std
::
vector
<
int64_t
>
(
z_shape
));
// It is wired, need to copy manually.
// It is wired, need to copy manually.
*
block_
->
add_ops
()
=
*
fc
->
Proto
();
*
block_
->
add_ops
()
=
*
fc
->
Proto
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录