Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e5674f6d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e5674f6d
编写于
8月 18, 2018
作者:
Z
Zhaolong Xing
提交者:
GitHub
8月 18, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12753 from NHZlX/add_benchmark
modify tensorrt engine op from cpu mode to gpu
上级
31070872
1e92baf7
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
239 addition
and
168 deletion
+239
-168
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+1
-2
paddle/fluid/inference/analysis/subgraph_splitter.cc
paddle/fluid/inference/analysis/subgraph_splitter.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+16
-6
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+14
-6
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+15
-12
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+14
-7
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+9
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+23
-5
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-1
paddle/fluid/operators/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt_engine_op.cc
+0
-105
paddle/fluid/operators/tensorrt_engine_op.cu.cc
paddle/fluid/operators/tensorrt_engine_op.cu.cc
+24
-0
paddle/fluid/operators/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt_engine_op.h
+100
-6
paddle/fluid/operators/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt_engine_op_test.cc
+17
-17
未找到文件。
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
e5674f6d
...
...
@@ -23,7 +23,7 @@
namespace
paddle
{
namespace
inference
{
DEFINE_int32
(
tensorrt_max_batchsize
,
3
,
"TensorRT maximum batch size"
);
DEFINE_int32
(
tensorrt_max_batchsize
,
1
,
"TensorRT maximum batch size"
);
DEFINE_int32
(
tensorrt_workspace_size
,
2048
,
"TensorRT workspace size"
);
namespace
analysis
{
...
...
@@ -52,7 +52,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool
DataFlowGraphToFluidPass
::
Finalize
()
{
return
true
;
}
void
DataFlowGraphToFluidPass
::
Run
(
DataFlowGraph
*
graph
)
{
FilterRedundantOutputOfSubGraph
(
graph
);
LOG
(
INFO
)
<<
"graph.inputs "
<<
graph
->
inputs
.
size
();
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
graph
).
nodes_in_TS
())
{
if
(
node
.
deleted
())
continue
;
...
...
paddle/fluid/inference/analysis/subgraph_splitter.cc
浏览文件 @
e5674f6d
...
...
@@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
inlink_or_outlink_cleaner
(
o
->
inlinks
);
}
}
FilterRedundantOutputOfSubGraph
(
graph_
);
}
}
// namespace analysis
...
...
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
e5674f6d
...
...
@@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter {
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Filter"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
4UL
);
const
int
n_output
=
Y_t
->
dims
()[
0
];
const
int
filter_h
=
Y_t
->
dims
()[
2
];
const
int
filter_w
=
Y_t
->
dims
()[
3
];
platform
::
CPUPlace
cpu_place
;
std
::
unique_ptr
<
framework
::
LoDTensor
>
weight_tensor
(
new
framework
::
LoDTensor
());
weight_tensor
->
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
weight_tensor
.
get
());
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
weight_tensor
->
dims
().
size
(),
4UL
);
const
int
n_output
=
weight_tensor
->
dims
()[
0
];
const
int
filter_h
=
weight_tensor
->
dims
()[
2
];
const
int
filter_w
=
weight_tensor
->
dims
()[
3
];
const
int
groups
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"groups"
));
const
std
::
vector
<
int
>
dilations
=
...
...
@@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter {
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
weight_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
...
...
@@ -70,6 +78,8 @@ class Conv2dOpConverter : public OpConverter {
layer
->
setNbGroups
(
groups
);
auto
output_name
=
op_desc
.
Output
(
"Output"
).
front
();
engine_
->
weight_map
[
op_desc
.
Input
(
"Filter"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
...
...
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
浏览文件 @
e5674f6d
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
...
...
@@ -40,10 +39,17 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Y"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
platform
::
CPUPlace
cpu_place
;
std
::
unique_ptr
<
framework
::
LoDTensor
>
weight_tensor
(
new
framework
::
LoDTensor
());
weight_tensor
->
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
weight_tensor
.
get
());
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
scale_mode
=
nvinfer1
::
ScaleMode
::
kELEMENTWISE
;
std
::
vector
<
int
>
dims_y
=
framework
::
vectorize2int
(
Y_t
->
dims
());
std
::
vector
<
int
>
dims_y
=
framework
::
vectorize2int
(
weight_tensor
->
dims
());
if
(
static_cast
<
int
>
(
dims_y
.
size
())
==
dims_x
.
nbDims
+
1
)
{
if
(
dims_y
[
0
]
==
1
)
dims_y
.
erase
(
dims_y
.
begin
());
}
...
...
@@ -70,9 +76,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
PADDLE_THROW
(
"TensorRT unsupported weight Shape for Elementwise op!"
);
}
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
weight_tensor
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
...
...
@@ -82,6 +88,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
engine_
,
Scale
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
scale_mode
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
e5674f6d
...
...
@@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -73,19 +68,26 @@ class FcOpConverter : public OpConverter {
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, that can't be avoided.
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
);
// a matrix
size_t
n_output
=
Y_t
->
dims
()[
1
];
platform
::
CPUPlace
cpu_place
;
framework
::
LoDTensor
weight_tensor
;
weight_tensor
.
Resize
(
Y_t
->
dims
());
TensorCopySync
((
*
Y_t
),
cpu_place
,
&
weight_tensor
);
framework
::
LoDTensor
tmp
;
tmp
.
Resize
(
Y_t
->
dims
());
memcpy
(
tmp
.
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
weight_data
,
auto
*
weight_data
=
weight_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
weight_tensor
.
dims
().
size
(),
2UL
);
// a matrix
size_t
n_output
=
weight_tensor
.
dims
()[
1
];
std
::
unique_ptr
<
framework
::
Tensor
>
tmp
(
new
framework
::
LoDTensor
());
tmp
->
Resize
(
weight_tensor
.
dims
());
memcpy
(
tmp
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
weight_data
,
Y_t
->
dims
()[
0
]
*
Y_t
->
dims
()[
1
]
*
sizeof
(
float
));
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
tmp_weight
(
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
tmp
.
data
<
float
>
()),
static_cast
<
void
*>
(
tmp
->
data
<
float
>
()),
Y_t
->
memory_size
()
/
sizeof
(
float
));
weight
.
dims
.
assign
({
Y_t
->
dims
()[
0
],
Y_t
->
dims
()[
1
]});
tmp_weight
.
dims
=
weight
.
dims
;
...
...
@@ -106,6 +108,7 @@ class FcOpConverter : public OpConverter {
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
tmp
);
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
...
...
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
浏览文件 @
e5674f6d
...
...
@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) {
auto
*
x
=
scope
.
Var
(
"conv2d-Y"
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
x_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
0
));
OpConverter
converter
;
converter
.
ConvertBlock
(
*
block
->
Proto
(),
{
"conv2d-Y"
},
scope
,
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
e5674f6d
...
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
...
...
@@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
auto
dims
=
tensor
->
dims
();
size_t
num_elements
=
analysis
::
AccuDims
(
dims
,
dims
.
size
());
PADDLE_ENFORCE_GT
(
num_elements
,
0
);
auto
*
data
=
tensor
->
mutable_data
<
float
>
(
place
);
platform
::
CPUPlace
cpu_place
;
framework
::
LoDTensor
temp_tensor
;
temp_tensor
.
Resize
(
dims
);
auto
*
temp_data
=
temp_tensor
.
mutable_data
<
float
>
(
cpu_place
);
for
(
size_t
i
=
0
;
i
<
num_elements
;
i
++
)
{
*
(
data
+
i
)
=
random
(
0.
,
1.
);
*
(
temp_
data
+
i
)
=
random
(
0.
,
1.
);
}
TensorCopySync
(
temp_tensor
,
place
,
tensor
);
}
/*
...
...
@@ -101,8 +108,8 @@ class TRTConvertValidation {
}
void
DeclVar
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
dim_vec
)
{
platform
::
C
PU
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
Place
place
;
platform
::
C
UDA
DeviceContext
ctx
(
place
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
@@ -141,7 +148,7 @@ class TRTConvertValidation {
PADDLE_ENFORCE
(
var
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
engine_
->
SetInputFrom
C
PU
(
engine_
->
SetInputFrom
G
PU
(
input
,
static_cast
<
void
*>
(
tensor
->
data
<
void
>
()),
sizeof
(
float
)
*
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
...
...
@@ -151,8 +158,8 @@ class TRTConvertValidation {
void
Execute
(
int
batch_size
)
{
// Execute Fluid Op
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_size_
);
platform
::
C
PU
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
Place
place
;
platform
::
C
UDA
DeviceContext
ctx
(
place
);
op_
->
Run
(
scope_
,
place
);
// Execute TRT.
engine_
->
Execute
(
batch_size
);
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
e5674f6d
...
...
@@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
}
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
freshDeviceId
();
batch_size_
=
batch_size
;
std
::
vector
<
void
*>
buffers
;
for
(
auto
&
buf
:
buffers_
)
{
...
...
@@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() {
}
void
TensorRTEngine
::
FreezeNetwork
()
{
freshDeviceId
();
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
"Call InitNetwork first to initialize network."
);
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
...
...
@@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
void
TensorRTEngine
::
freshDeviceId
()
{
int
count
;
cudaGetDeviceCount
(
&
count
);
PADDLE_ENFORCE_LT
(
device_
,
count
);
cudaSetDevice
(
device_
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
e5674f6d
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
...
...
@@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase {
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
=
nullptr
,
cudaStream_t
*
stream
=
nullptr
,
int
device
=
0
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
?
stream
:
&
default_stream_
),
logger_
(
logger
)
{
cudaStreamCreate
(
&
default_stream_
);
logger_
(
logger
),
device_
(
device
)
{
freshDeviceId
();
cudaStreamCreate
(
stream_
);
}
virtual
~
TensorRTEngine
();
...
...
@@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
void
SetRuntimeBatch
(
size_t
batch_size
);
int
GetRuntimeBatch
();
int
GetDevice
()
{
return
device_
;
}
// A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage.
// so we need to copy the weights from GPU to CPU in our op converter.
// We use a map to store these weights for the weight memory is not released
// in advance, which affecting the construction of TRT Op.
std
::
unordered_map
<
std
::
string
/*name*/
,
std
::
unique_ptr
<
framework
::
Tensor
>>
weight_map
;
private:
// the max batch size
...
...
@@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
// TensorRT related internal members
template
<
typename
T
>
...
...
@@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase {
infer_ptr
<
nvinfer1
::
INetworkDefinition
>
infer_network_
;
infer_ptr
<
nvinfer1
::
ICudaEngine
>
infer_engine_
;
infer_ptr
<
nvinfer1
::
IExecutionContext
>
infer_context_
;
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void
freshDeviceId
();
};
// class TensorRTEngine
// Add an layer__ into engine__ with args ARGS.
...
...
@@ -188,8 +206,8 @@ class TRT_EngineManager {
// Create or get an engine called `name`
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
const
std
::
string
&
name
)
{
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
);
const
std
::
string
&
name
,
int
gpu_device
=
0
)
{
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
,
gpu_device
);
engines_
[
name
].
reset
(
p
);
return
p
;
}
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
e5674f6d
...
...
@@ -27,7 +27,7 @@ namespace tensorrt {
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream_
));
//
ASSERT_EQ(0, cudaStreamCreate(&stream_));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
);
engine_
->
InitNetwork
();
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
e5674f6d
...
...
@@ -100,7 +100,8 @@ function(op_library TARGET)
endif
()
# Define operators that don't need pybind here.
foreach
(
manual_pybind_op
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
)
foreach
(
manual_pybind_op
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
"tensorrt_engine_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
set
(
pybind_flag 1
)
endif
()
...
...
@@ -248,6 +249,7 @@ op_library(softmax_op DEPS softmax)
op_library
(
sequence_softmax_op DEPS softmax
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(tensorrt_engine);
\n
"
)
nv_test
(
test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op
analysis
)
...
...
paddle/fluid/operators/tensorrt_engine_op.cc
浏览文件 @
e5674f6d
...
...
@@ -17,10 +17,6 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace
paddle
{
...
...
@@ -29,100 +25,6 @@ DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
namespace
operators
{
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
using
FluidDT
=
framework
::
proto
::
VarType_Type
;
using
TRT_DT
=
nvinfer1
::
DataType
;
namespace
{
TRT_DT
FluidDataType2TRT
(
FluidDT
type
)
{
switch
(
type
)
{
case
FluidDT
::
VarType_Type_FP32
:
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
default:
return
TRT_DT
::
kINT32
;
}
PADDLE_THROW
(
"unkown type"
);
return
TRT_DT
::
kINT32
;
}
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
int64_t
>
&
shape
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1UL
,
"TensorRT' tensor input requires at least 2 dimensions"
);
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
"TensorRT' tensor input requires at most 4 dimensions"
);
PADDLE_ENFORCE_EQ
(
shape
.
size
(),
4UL
);
return
nvinfer1
::
DimsCHW
(
shape
[
1
],
shape
[
2
],
shape
[
3
]);
}
}
// namespace
template
<
typename
DeviceContext
,
typename
T
>
void
TensorRTEngineKernel
<
DeviceContext
,
T
>::
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
// Get the ProgramDesc and pass to convert.
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
context
.
Attr
<
std
::
string
>
(
"subgraph"
));
int
max_batch
=
context
.
Attr
<
int
>
(
"max_batch"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
auto
params
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
std
::
unordered_set
<
std
::
string
>
parameters
;
for
(
const
auto
&
param
:
params
)
{
parameters
.
insert
(
param
);
}
std
::
vector
<
std
::
string
>
output_maps
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// TODO(Superjomn) replace this with a different stream
auto
*
engine
=
Singleton
<
TRT_EngineManager
>::
Global
().
Create
(
max_batch
,
max_workspace
,
nullptr
/*engine hold its own stream*/
,
context
.
Attr
<
std
::
string
>
(
"engine_uniq_key"
));
engine
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
context
.
Inputs
(
"Xs"
))
{
if
(
parameters
.
count
(
input
))
continue
;
VLOG
(
4
)
<<
"declare input "
<<
input
;
auto
*
var
=
block
.
FindVar
(
input
);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
auto
shape
=
var
->
GetShape
();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if
(
shape
[
0
]
==
-
1
)
{
shape
[
0
]
=
FLAGS_tensorrt_engine_batch_size
;
}
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
shape
));
}
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
block_desc
,
parameters
,
context
.
scope
(),
engine
);
// Add outputs
for
(
auto
&
output
:
output_maps
)
{
engine
->
DeclareOutput
(
output
);
}
engine
->
FreezeNetwork
();
}
class
TensorRTEngineOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -150,11 +52,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
tensorrt_engine
,
ops
::
TensorRTEngineOp
,
ops
::
TensorRTEngineOpMaker
,
ops
::
TensorRTEngineOpMaker
);
REGISTER_OP_CPU_KERNEL
(
tensorrt_engine
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
#endif // PADDLE_WITH_CUDA
paddle/fluid/operators/tensorrt_engine_op.cu.cc
0 → 100644
浏览文件 @
e5674f6d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
tensorrt_engine
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/tensorrt_engine_op.h
浏览文件 @
e5674f6d
...
...
@@ -19,8 +19,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace
paddle
{
...
...
@@ -29,6 +31,35 @@ DECLARE_int32(tensorrt_engine_batch_size);
namespace
operators
{
using
FluidDT
=
framework
::
proto
::
VarType_Type
;
using
TRT_DT
=
nvinfer1
::
DataType
;
namespace
{
TRT_DT
FluidDataType2TRT
(
FluidDT
type
)
{
switch
(
type
)
{
case
FluidDT
::
VarType_Type_FP32
:
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
default:
return
TRT_DT
::
kINT32
;
}
PADDLE_THROW
(
"unkown type"
);
return
TRT_DT
::
kINT32
;
}
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
int64_t
>&
shape
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1UL
,
"TensorRT' tensor input requires at least 2 dimensions"
);
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
"TensorRT' tensor input requires at most 4 dimensions"
);
PADDLE_ENFORCE_EQ
(
shape
.
size
(),
4UL
);
return
nvinfer1
::
DimsCHW
(
shape
[
1
],
shape
[
2
],
shape
[
3
]);
}
}
// namespace
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
...
...
@@ -47,7 +78,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
.
FindVar
(
input0
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
()),
platform
::
CPU
Place
());
ctx
.
Get
Place
());
return
kt
;
}
};
...
...
@@ -94,7 +125,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// Convert output tensor from engine to fluid
int
output_index
=
0
;
VLOG
(
4
)
<<
"TensorRT Engine Op Outputs:"
;
for
(
const
auto
&
y
:
context
.
Outputs
(
"Ys"
))
{
VLOG
(
4
)
<<
y
;
// convert output and copy to fluid.
nvinfer1
::
ITensor
*
trt_t
=
engine
->
GetITensor
(
output_maps
[
output_index
]);
auto
dims
=
trt_t
->
getDimensions
();
...
...
@@ -113,8 +146,10 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// TODO(Superjomn) change this float to dtype size.
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
)
*
FLAGS_tensorrt_engine_batch_size
;
engine
->
GetOutputInCPU
(
output_maps
[
output_index
],
fluid_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
engine
->
GetOutputInGPU
(
output_maps
[
output_index
],
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
)),
size
*
sizeof
(
float
));
//} else {
// engine->GetOutputInGPU(
...
...
@@ -128,8 +163,67 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
}
protected:
// Build the engine.
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
;
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
// Get the ProgramDesc and pass to convert.
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
context
.
Attr
<
std
::
string
>
(
"subgraph"
));
int
max_batch
=
context
.
Attr
<
int
>
(
"max_batch"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
auto
params
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"parameters"
);
std
::
unordered_set
<
std
::
string
>
parameters
;
for
(
const
auto
&
param
:
params
)
{
parameters
.
insert
(
param
);
}
std
::
vector
<
std
::
string
>
output_maps
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// TODO(Superjomn) replace this with a different stream
auto
*
engine
=
Singleton
<
TRT_EngineManager
>::
Global
().
Create
(
max_batch
,
max_workspace
,
nullptr
/*engine hold its own stream*/
,
context
.
Attr
<
std
::
string
>
(
"engine_uniq_key"
),
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
);
engine
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
context
.
Inputs
(
"Xs"
))
{
if
(
parameters
.
count
(
input
))
continue
;
VLOG
(
4
)
<<
"declare input "
<<
input
;
auto
*
var
=
block
.
FindVar
(
input
);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
auto
shape
=
var
->
GetShape
();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if
(
shape
[
0
]
==
-
1
)
{
shape
[
0
]
=
FLAGS_tensorrt_engine_batch_size
;
}
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
shape
));
}
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
()
.
ConvertBlock
(
block_desc
,
parameters
,
context
.
scope
(),
engine
);
// Add outputs
for
(
auto
&
output
:
output_maps
)
{
engine
->
DeclareOutput
(
output
);
}
engine
->
FreezeNetwork
();
}
};
}
// namespace operators
...
...
paddle/fluid/operators/tensorrt_engine_op_test.cc
浏览文件 @
e5674f6d
...
...
@@ -23,20 +23,20 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_C
PU
_ONLY_OP
(
tensorrt_engine
);
USE_C
UDA
_ONLY_OP
(
tensorrt_engine
);
namespace
paddle
{
namespace
operators
{
namespace
{
void
CreateC
PU
Tensor
(
framework
::
Scope
*
scope
,
const
std
::
string
&
name
,
void
CreateC
UDA
Tensor
(
framework
::
Scope
*
scope
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
shape
)
{
auto
*
var
=
scope
->
Var
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
dims
=
framework
::
make_ddim
(
shape
);
tensor
->
Resize
(
dims
);
platform
::
C
PU
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
Place
place
;
platform
::
C
UDA
DeviceContext
ctx
(
place
);
inference
::
tensorrt
::
RandomizeTensor
(
tensor
,
place
,
ctx
);
}
...
...
@@ -112,15 +112,15 @@ TEST(TensorRTEngineOp, manual) {
LOG
(
INFO
)
<<
"engine_op "
<<
engine_op
.
get
();
framework
::
Scope
scope
;
platform
::
C
PU
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
Place
place
;
platform
::
C
UDA
DeviceContext
ctx
(
place
);
// Prepare variables.
CreateC
PU
Tensor
(
&
scope
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
CreateC
PU
Tensor
(
&
scope
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
CreateC
PU
Tensor
(
&
scope
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
CreateC
UDA
Tensor
(
&
scope
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
CreateC
UDA
Tensor
(
&
scope
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
CreateC
UDA
Tensor
(
&
scope
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
CreateC
PU
Tensor
(
&
scope
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
CreateC
PU
Tensor
(
&
scope
,
"z0"
,
std
::
vector
<
int64_t
>
({
2
,
8
}));
CreateC
UDA
Tensor
(
&
scope
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
CreateC
UDA
Tensor
(
&
scope
,
"z0"
,
std
::
vector
<
int64_t
>
({
2
,
8
}));
// Execute them.
LOG
(
INFO
)
<<
"engine_op run"
;
...
...
@@ -130,8 +130,8 @@ TEST(TensorRTEngineOp, manual) {
void
Execute
(
int
batch_size
,
int
input_dim
,
int
output_dim
,
int
nlayers
=
1
)
{
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
C
PU
Place
place
;
platform
::
C
PU
DeviceContext
ctx
(
place
);
platform
::
C
UDA
Place
place
;
platform
::
C
UDA
DeviceContext
ctx
(
place
);
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
block_
->
set_idx
(
0
);
...
...
@@ -165,10 +165,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
// Prepare variables.
if
(
!
x_created
)
{
CreateC
PU
Tensor
(
&
scope
,
x_name
,
std
::
vector
<
int64_t
>
(
x_shape
));
CreateC
UDA
Tensor
(
&
scope
,
x_name
,
std
::
vector
<
int64_t
>
(
x_shape
));
}
CreateC
PU
Tensor
(
&
scope
,
y_name
,
std
::
vector
<
int64_t
>
(
y_shape
));
CreateC
PU
Tensor
(
&
scope
,
z_name
,
std
::
vector
<
int64_t
>
(
z_shape
));
CreateC
UDA
Tensor
(
&
scope
,
y_name
,
std
::
vector
<
int64_t
>
(
y_shape
));
CreateC
UDA
Tensor
(
&
scope
,
z_name
,
std
::
vector
<
int64_t
>
(
z_shape
));
// It is wired, need to copy manually.
*
block_
->
add_ops
()
=
*
fc
->
Proto
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录