Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
281d7c34
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
338
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
281d7c34
编写于
6月 28, 2020
作者:
J
jiweibo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update shared_ptr. test=develop
上级
f6ba9268
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
32 addition
and
32 deletion
+32
-32
lite/api/cxx_api.h
lite/api/cxx_api.h
+2
-2
lite/api/cxx_api_impl.cc
lite/api/cxx_api_impl.cc
+10
-12
lite/api/paddle_api.h
lite/api/paddle_api.h
+8
-6
lite/api/test_resnet50_lite_cuda.cc
lite/api/test_resnet50_lite_cuda.cc
+12
-12
未找到文件。
lite/api/cxx_api.h
浏览文件 @
281d7c34
...
...
@@ -256,8 +256,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
bool
status_is_cloned_
;
#ifdef LITE_WITH_CUDA
bool
multi_stream_
{
false
};
cudaStream_t
*
io_stream_
{
nullptr
}
;
cudaStream_t
*
exec_stream_
{
nullptr
}
;
std
::
shared_ptr
<
cudaStream_t
>
io_stream_
;
std
::
shared_ptr
<
cudaStream_t
>
exec_stream_
;
cudaEvent_t
input_event_
;
std
::
vector
<
cudaEvent_t
>
output_events_
;
// only for multi exec stream mode.
...
...
lite/api/cxx_api_impl.cc
浏览文件 @
281d7c34
...
...
@@ -95,18 +95,18 @@ void CxxPaddleApiImpl::CudaEnvInit(std::vector<std::string> *passes) {
if
(
config_
.
exec_stream
())
{
exec_stream_
=
config_
.
exec_stream
();
}
else
{
exec_stream_
=
new
cudaStream_t
();
TargetWrapperCuda
::
CreateStream
(
exec_stream_
);
exec_stream_
=
std
::
make_shared
<
cudaStream_t
>
();
TargetWrapperCuda
::
CreateStream
(
exec_stream_
.
get
()
);
}
if
(
config_
.
io_stream
())
{
io_stream_
=
config_
.
io_stream
();
}
else
{
io_stream_
=
new
cudaStream_t
();
TargetWrapperCuda
::
CreateStream
(
io_stream_
);
io_stream_
=
std
::
make_shared
<
cudaStream_t
>
();
TargetWrapperCuda
::
CreateStream
(
io_stream_
.
get
()
);
}
raw_predictor_
->
set_exec_stream
(
exec_stream_
);
raw_predictor_
->
set_io_stream
(
io_stream_
);
raw_predictor_
->
set_exec_stream
(
exec_stream_
.
get
()
);
raw_predictor_
->
set_io_stream
(
io_stream_
.
get
()
);
// init sync events.
if
(
config_
.
multi_stream
())
{
...
...
@@ -158,7 +158,8 @@ void CxxPaddleApiImpl::OutputSync() {
std
::
unique_ptr
<
lite_api
::
Tensor
>
CxxPaddleApiImpl
::
GetInput
(
int
i
)
{
auto
*
x
=
raw_predictor_
->
GetInput
(
i
);
#ifdef LITE_WITH_CUDA
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
,
io_stream_
));
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
,
io_stream_
.
get
()));
#else
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
));
#endif
...
...
@@ -168,7 +169,8 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
int
i
)
const
{
const
auto
*
x
=
raw_predictor_
->
GetOutput
(
i
);
#ifdef LITE_WITH_CUDA
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
,
io_stream_
));
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
,
io_stream_
.
get
()));
#else
return
std
::
unique_ptr
<
lite_api
::
Tensor
>
(
new
lite_api
::
Tensor
(
x
));
#endif
...
...
@@ -250,10 +252,6 @@ CxxPaddleApiImpl::~CxxPaddleApiImpl() {
for
(
size_t
i
=
0
;
i
<
output_events_
.
size
();
++
i
)
{
TargetWrapperCuda
::
DestroyEvent
(
output_events_
[
i
]);
}
if
(
multi_stream_
)
{
TargetWrapperCuda
::
DestroyStream
(
*
io_stream_
);
TargetWrapperCuda
::
DestroyStream
(
*
exec_stream_
);
}
#endif
}
...
...
lite/api/paddle_api.h
浏览文件 @
281d7c34
...
...
@@ -167,8 +167,8 @@ class LITE_API CxxConfig : public ConfigBase {
#endif
#ifdef LITE_WITH_CUDA
bool
multi_stream_
{
false
};
cudaStream_t
*
exec_stream_
{
nullptr
}
;
cudaStream_t
*
io_stream_
{
nullptr
}
;
std
::
shared_ptr
<
cudaStream_t
>
exec_stream_
;
std
::
shared_ptr
<
cudaStream_t
>
io_stream_
;
#endif
#ifdef LITE_WITH_MLU
lite_api
::
MLUCoreVersion
mlu_core_version_
{
lite_api
::
MLUCoreVersion
::
MLU_270
};
...
...
@@ -217,12 +217,14 @@ class LITE_API CxxConfig : public ConfigBase {
#ifdef LITE_WITH_CUDA
void
set_multi_stream
(
bool
multi_stream
)
{
multi_stream_
=
multi_stream
;
}
bool
multi_stream
()
const
{
return
multi_stream_
;
}
void
set_exec_stream
(
cudaStream_t
*
exec_stream
)
{
void
set_exec_stream
(
std
::
shared_ptr
<
cudaStream_t
>
exec_stream
)
{
exec_stream_
=
exec_stream
;
}
void
set_io_stream
(
cudaStream_t
*
io_stream
)
{
io_stream_
=
io_stream
;
}
cudaStream_t
*
exec_stream
()
{
return
exec_stream_
;
}
cudaStream_t
*
io_stream
()
{
return
io_stream_
;
}
void
set_io_stream
(
std
::
shared_ptr
<
cudaStream_t
>
io_stream
)
{
io_stream_
=
io_stream
;
}
std
::
shared_ptr
<
cudaStream_t
>
exec_stream
()
{
return
exec_stream_
;
}
std
::
shared_ptr
<
cudaStream_t
>
io_stream
()
{
return
io_stream_
;
}
#endif
#ifdef LITE_WITH_MLU
...
...
lite/api/test_resnet50_lite_cuda.cc
浏览文件 @
281d7c34
...
...
@@ -95,9 +95,9 @@ TEST(Resnet50, config_exec_stream) {
lite_api
::
CxxConfig
config
;
config
.
set_model_dir
(
FLAGS_model_dir
);
config
.
set_valid_places
({
lite_api
::
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)}});
cudaStream_t
exec_stream
;
lite
::
TargetWrapperCuda
::
CreateStream
(
&
exec_stream
);
config
.
set_exec_stream
(
&
exec_stream
);
std
::
shared_ptr
<
cudaStream_t
>
exec_stream
=
std
::
make_shared
<
cudaStream_t
>
()
;
lite
::
TargetWrapperCuda
::
CreateStream
(
exec_stream
.
get
()
);
config
.
set_exec_stream
(
exec_stream
);
RunModel
(
config
);
}
...
...
@@ -106,9 +106,9 @@ TEST(Resnet50, config_io_stream) {
lite_api
::
CxxConfig
config
;
config
.
set_model_dir
(
FLAGS_model_dir
);
config
.
set_valid_places
({
lite_api
::
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)}});
cudaStream_t
io_stream
;
lite
::
TargetWrapperCuda
::
CreateStream
(
&
io_stream
);
config
.
set_io_stream
(
&
io_stream
);
std
::
shared_ptr
<
cudaStream_t
>
io_stream
=
std
::
make_shared
<
cudaStream_t
>
()
;
lite
::
TargetWrapperCuda
::
CreateStream
(
io_stream
.
get
()
);
config
.
set_io_stream
(
io_stream
);
RunModel
(
config
);
}
...
...
@@ -117,12 +117,12 @@ TEST(Resnet50, config_all_stream) {
lite_api
::
CxxConfig
config
;
config
.
set_model_dir
(
FLAGS_model_dir
);
config
.
set_valid_places
({
lite_api
::
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)}});
cudaStream_t
exec_stream
;
lite
::
TargetWrapperCuda
::
CreateStream
(
&
exec_stream
);
config
.
set_exec_stream
(
&
exec_stream
);
cudaStream_t
io_stream
;
lite
::
TargetWrapperCuda
::
CreateStream
(
&
io_stream
);
config
.
set_io_stream
(
&
io_stream
);
std
::
shared_ptr
<
cudaStream_t
>
exec_stream
=
std
::
make_shared
<
cudaStream_t
>
()
;
lite
::
TargetWrapperCuda
::
CreateStream
(
exec_stream
.
get
()
);
config
.
set_exec_stream
(
exec_stream
);
std
::
shared_ptr
<
cudaStream_t
>
io_stream
=
std
::
make_shared
<
cudaStream_t
>
()
;
lite
::
TargetWrapperCuda
::
CreateStream
(
io_stream
.
get
()
);
config
.
set_io_stream
(
io_stream
);
RunModel
(
config
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录