Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ec213730
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec213730
编写于
1月 22, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix trt stream bug.
BUG: After continuing to input different data, the output cannot be aligned test=develop
上级
e2ba9668
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
31 addition
and
66 deletion
+31
-66
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+2
-2
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+4
-6
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+7
-9
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+7
-43
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+2
-2
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+7
-2
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
+2
-2
未找到文件。
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
浏览文件 @
ec213730
...
...
@@ -29,9 +29,9 @@ TEST(OpConverter, ConvertBlock) {
// init trt engine
cudaStream_t
stream_
;
std
::
unique_ptr
<
TensorRTEngine
>
engine_
;
engine_
.
reset
(
new
TensorRTEngine
(
5
,
1
<<
15
,
&
stream_
));
engine_
->
InitNetwork
();
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
engine_
.
reset
(
new
TensorRTEngine
(
5
,
1
<<
15
,
stream_
));
engine_
->
InitNetwork
();
engine_
->
DeclareInput
(
"conv2d-X"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
Dims3
(
2
,
5
,
5
));
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
ec213730
...
...
@@ -78,11 +78,9 @@ class TRTConvertValidation {
scope_
(
scope
),
if_add_batch_
(
if_add_batch
),
max_batch_size_
(
max_batch_size
)
{
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size
,
workspace_size
,
&
stream_
));
engine_
->
InitNetwork
();
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size
,
workspace_size
,
stream_
));
engine_
->
InitNetwork
();
}
// Declare a Variable as input with random initialization.
...
...
@@ -175,7 +173,7 @@ class TRTConvertValidation {
op_
->
Run
(
scope_
,
place
);
// Execute TRT.
engine_
->
Execute
(
batch_size
);
cudaStreamSynchronize
(
*
engine_
->
stream
());
cudaStreamSynchronize
(
engine_
->
stream
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
const
size_t
output_space_size
=
3000
;
...
...
@@ -184,7 +182,7 @@ class TRTConvertValidation {
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
trt_out
(
output_space_size
);
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
output_space_size
);
cudaStreamSynchronize
(
*
engine_
->
stream
());
cudaStreamSynchronize
(
engine_
->
stream
());
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
ec213730
...
...
@@ -42,14 +42,13 @@ void TensorRTEngine::Execute(int batch_size) {
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
buffers
.
push_back
(
buf
.
buffer
);
}
PADDLE_ENFORCE_NOT_NULL
(
stream_
);
infer_context_
->
enqueue
(
batch_size
,
buffers
.
data
(),
*
stream_
,
nullptr
);
cudaStreamSynchronize
(
*
stream_
);
infer_context_
->
enqueue
(
batch_size
,
buffers
.
data
(),
stream_
,
nullptr
);
cudaStreamSynchronize
(
stream_
);
SetRuntimeBatch
(
batch_size
);
}
TensorRTEngine
::~
TensorRTEngine
()
{
cudaStreamSynchronize
(
*
stream_
);
cudaStreamSynchronize
(
stream_
);
// clean buffer
for
(
auto
&
buf
:
buffers_
)
{
if
(
buf
.
device
==
DeviceType
::
GPU
&&
buf
.
buffer
!=
nullptr
)
{
...
...
@@ -173,7 +172,7 @@ void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
,
"buffer should be allocated before"
);
PADDLE_ENFORCE_EQ
(
cudaMemcpyAsync
(
dst
,
buf
.
buffer
,
dst_size
,
cudaMemcpyDeviceToDevice
,
*
stream_
),
cudaMemcpyDeviceToDevice
,
stream_
),
0
);
}
...
...
@@ -194,7 +193,7 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
,
"buffer should be allocated before"
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
dst
,
buf
.
buffer
,
dst_size
,
cudaMemcpyDeviceToHost
,
*
stream_
));
cudaMemcpyDeviceToHost
,
stream_
));
}
Buffer
&
TensorRTEngine
::
buffer
(
const
std
::
string
&
name
)
{
...
...
@@ -211,12 +210,11 @@ void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
auto
&
buf
=
buffer
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
);
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE_NOT_NULL
(
stream_
);
PADDLE_ENFORCE_LE
(
size
,
buf
.
max_size
,
"buffer is too small"
);
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
buf
.
size
=
size
;
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
.
buffer
,
data
,
size
,
cudaMemcpyHostToDevice
,
*
stream_
));
cudaMemcpyHostToDevice
,
stream_
));
}
void
TensorRTEngine
::
SetInputFromGPU
(
const
std
::
string
&
name
,
const
void
*
data
,
...
...
@@ -227,7 +225,7 @@ void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
PADDLE_ENFORCE_LE
(
size
,
buf
.
max_size
,
"buffer is too small"
);
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
.
buffer
,
data
,
size
,
cudaMemcpyDeviceToDevice
,
*
stream_
));
cudaMemcpyDeviceToDevice
,
stream_
));
}
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
ec213730
...
...
@@ -54,17 +54,14 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
Weights
w_
;
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
=
nullptr
,
int
device
=
0
,
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
stream
,
int
device
=
0
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
?
stream
:
&
default_stream_
),
stream_
(
stream
),
logger_
(
logger
),
device_
(
device
)
{
freshDeviceId
();
cudaStreamCreate
(
stream_
);
}
device_
(
device
)
{}
virtual
~
TensorRTEngine
();
...
...
@@ -102,7 +99,7 @@ class TensorRTEngine : public EngineBase {
// NOTE this should be used after calling `FreezeNetwork`.
Buffer
&
buffer
(
const
std
::
string
&
name
)
override
;
cudaStream_t
*
stream
()
{
return
stream_
;
}
cudaStream_t
stream
()
{
return
stream_
;
}
// Fill an input from CPU memory with name and size.
void
SetInputFromCPU
(
const
std
::
string
&
name
,
const
void
*
data
,
size_t
size
);
...
...
@@ -158,9 +155,8 @@ class TensorRTEngine : public EngineBase {
// batch size of the current data, will be updated each Executation.
int
batch_size_
{
-
1
};
cudaStream_t
*
stream_
;
// If stream_ is not set from outside, hold its own stream.
cudaStream_t
default_stream_
;
cudaStream_t
stream_
;
nvinfer1
::
ILogger
&
logger_
;
std
::
vector
<
Buffer
>
buffers_
;
...
...
@@ -208,38 +204,6 @@ class TensorRTEngine : public EngineBase {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class
TRT_EngineManager
{
public:
bool
HasEngine
(
const
std
::
string
&
name
)
const
{
return
engines_
.
count
(
name
)
!=
0
;
}
// Get an engine called `name`.
TensorRTEngine
*
Get
(
const
std
::
string
&
name
)
const
{
return
engines_
.
at
(
name
).
get
();
}
// Create or get an engine called `name`
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
const
std
::
string
&
name
,
int
gpu_device
=
0
)
{
auto
*
p
=
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
,
gpu_device
);
engines_
[
name
].
reset
(
p
);
return
p
;
}
void
DeleteALl
()
{
for
(
auto
&
item
:
engines_
)
{
item
.
second
.
reset
(
nullptr
);
}
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
TensorRTEngine
>>
engines_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
ec213730
...
...
@@ -27,8 +27,8 @@ namespace tensorrt {
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
//
ASSERT_EQ(0, cudaStreamCreate(&stream_));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
);
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream_
));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
stream_
);
engine_
->
InitNetwork
();
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
ec213730
...
...
@@ -96,9 +96,13 @@ class TensorRTEngineOp : public framework::OperatorBase {
void
RunTrt
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
int
runtime_batch
=
1
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
nullptr
,
max_batch_size_
,
workspace_size_
,
stream
,
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_place
).
device
));
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
}
...
...
@@ -126,6 +130,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
}
cudaStreamSynchronize
(
stream
);
PADDLE_ENFORCE_LE
(
runtime_batch
,
max_batch_size_
);
// Execute the engine.
engine
->
Execute
(
runtime_batch
);
...
...
@@ -163,7 +168,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
output_index
+=
1
;
}
cudaStreamSynchronize
(
*
engine
->
stream
()
);
cudaStreamSynchronize
(
stream
);
}
void
Prepare
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
浏览文件 @
ec213730
...
...
@@ -99,7 +99,7 @@ TEST(TensorRTEngineOp, manual) {
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch_size"
,
2
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
2
<<
1
0
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
2
<<
2
0
);
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"engine_uniq_key"
,
"a_engine"
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
std
::
vector
<
std
::
string
>
({}));
...
...
@@ -193,7 +193,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch_size"
,
batch_size
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
2
<<
1
0
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"workspace_size"
,
2
<<
2
0
);
SetAttr
<
std
::
vector
<
std
::
string
>>
(
engine_op_desc
.
Proto
(),
"parameters"
,
std
::
vector
<
std
::
string
>
({
"y0"
,
"y1"
,
"y2"
,
"y3"
}));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录