Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
4b54c594
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
338
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b54c594
编写于
7月 22, 2020
作者:
J
jiweibo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update name format. test=develop
上级
df43322a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
38 addition
and
27 deletion
+38
-27
lite/api/cxx_api.cc
lite/api/cxx_api.cc
+1
-1
lite/api/cxx_api.h
lite/api/cxx_api.h
+4
-4
lite/api/cxx_api_impl.cc
lite/api/cxx_api_impl.cc
+4
-4
lite/api/test_resnet50_lite_cuda.cc
lite/api/test_resnet50_lite_cuda.cc
+1
-1
lite/backends/cuda/CMakeLists.txt
lite/backends/cuda/CMakeLists.txt
+1
-1
lite/backends/cuda/stream_guard.cc
lite/backends/cuda/stream_guard.cc
+1
-1
lite/backends/cuda/stream_guard.h
lite/backends/cuda/stream_guard.h
+19
-8
lite/core/program.cc
lite/core/program.cc
+3
-3
lite/core/program.h
lite/core/program.h
+4
-4
未找到文件。
lite/api/cxx_api.cc
浏览文件 @
4b54c594
...
...
@@ -355,7 +355,7 @@ void Predictor::GenRuntimeProgram() {
program_generated_
=
true
;
#ifdef LITE_WITH_CUDA
if
(
!
cuda_use_multi_stream_
)
{
program_
->
UpdateContext
(
cuda_exec_stream_
,
cuda_io_stream_
);
program_
->
UpdateC
udaC
ontext
(
cuda_exec_stream_
,
cuda_io_stream_
);
}
#endif
}
...
...
lite/api/cxx_api.h
浏览文件 @
4b54c594
...
...
@@ -29,7 +29,7 @@
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/stream_
wrapper
.h"
#include "lite/backends/cuda/stream_
guard
.h"
#endif
namespace
paddle
{
...
...
@@ -254,12 +254,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
#ifdef LITE_WITH_CUDA
bool
cuda_use_multi_stream_
{
false
};
std
::
unique_ptr
<
lite
::
StreamWrapper
>
cuda_io_stream_
;
std
::
unique_ptr
<
lite
::
StreamWrapper
>
cuda_exec_stream_
;
std
::
unique_ptr
<
lite
::
CudaStreamGuard
>
cuda_io_stream_
;
std
::
unique_ptr
<
lite
::
CudaStreamGuard
>
cuda_exec_stream_
;
cudaEvent_t
cuda_input_event_
;
std
::
vector
<
cudaEvent_t
>
cuda_output_events_
;
// only used for multi exec stream mode.
std
::
vector
<
lite
::
StreamWrapper
>
cuda_exec_streams_
;
std
::
vector
<
lite
::
CudaStreamGuard
>
cuda_exec_streams_
;
#endif
};
...
...
lite/api/cxx_api_impl.cc
浏览文件 @
4b54c594
...
...
@@ -97,14 +97,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) {
// init two streams for each predictor.
if
(
config_
.
cuda_exec_stream
())
{
cuda_exec_stream_
.
reset
(
new
lite
::
StreamWrapper
(
*
config_
.
cuda_exec_stream
()));
new
lite
::
CudaStreamGuard
(
*
config_
.
cuda_exec_stream
()));
}
else
{
cuda_exec_stream_
.
reset
(
new
lite
::
StreamWrapper
());
cuda_exec_stream_
.
reset
(
new
lite
::
CudaStreamGuard
());
}
if
(
config_
.
cuda_io_stream
())
{
cuda_io_stream_
.
reset
(
new
lite
::
StreamWrapper
(
*
config_
.
cuda_io_stream
()));
cuda_io_stream_
.
reset
(
new
lite
::
CudaStreamGuard
(
*
config_
.
cuda_io_stream
()));
}
else
{
cuda_io_stream_
.
reset
(
new
lite
::
StreamWrapper
());
cuda_io_stream_
.
reset
(
new
lite
::
CudaStreamGuard
());
}
raw_predictor_
->
set_cuda_exec_stream
(
cuda_exec_stream_
->
stream
());
...
...
lite/api/test_resnet50_lite_cuda.cc
浏览文件 @
4b54c594
...
...
@@ -29,7 +29,7 @@
namespace
paddle
{
namespace
lite
{
void
RunModel
(
lite_api
::
CxxConfig
config
)
{
void
RunModel
(
const
lite_api
::
CxxConfig
&
config
)
{
auto
predictor
=
lite_api
::
CreatePaddlePredictor
(
config
);
const
int
batch_size
=
4
;
const
int
channels
=
3
;
...
...
lite/backends/cuda/CMakeLists.txt
浏览文件 @
4b54c594
...
...
@@ -9,6 +9,6 @@ nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_deps})
nv_library
(
nvtx_wrapper SRCS nvtx_wrapper DEPS
${
cuda_deps
}
)
lite_cc_library
(
cuda_context SRCS context.cc DEPS device_info
)
lite_cc_library
(
stream_
wrapper SRCS stream_wrapper
.cc DEPS target_wrapper_cuda
${
cuda_deps
}
)
lite_cc_library
(
stream_
guard SRCS stream_guard
.cc DEPS target_wrapper_cuda
${
cuda_deps
}
)
add_subdirectory
(
math
)
lite/backends/cuda/stream_
wrapper
.cc
→
lite/backends/cuda/stream_
guard
.cc
浏览文件 @
4b54c594
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/stream_
wrapper
.h"
#include "lite/backends/cuda/stream_
guard
.h"
#include "lite/backends/cuda/cuda_utils.h"
namespace
paddle
{
...
...
lite/backends/cuda/stream_
wrapper
.h
→
lite/backends/cuda/stream_
guard
.h
浏览文件 @
4b54c594
...
...
@@ -21,24 +21,35 @@
namespace
paddle
{
namespace
lite
{
class
StreamWrapper
{
// CudaStreamGuard is a encapsulation of cudaStream_t, which can accept external
// stream or internally created stream
//
// std::unique_ptr<CudaStreamGuard> sm;
//
// external stream: exec_stream
// sm.reset(new CudaStreamGuard(exec_stream));
// internal stream
// sm.reset(new CudaStreamGuard());
// get cudaStream_t
// sm->stream();
class
CudaStreamGuard
{
public:
explicit
StreamWrapper
(
cudaStream_t
stream
)
:
stream_
(
stream
),
owne
r
_
(
false
)
{}
StreamWrapper
()
:
owner
_
(
true
)
{
explicit
CudaStreamGuard
(
cudaStream_t
stream
)
:
stream_
(
stream
),
owne
d
_
(
false
)
{}
CudaStreamGuard
()
:
owned
_
(
true
)
{
lite
::
TargetWrapperCuda
::
CreateStream
(
&
stream_
);
}
~
StreamWrapper
()
{
if
(
owne
r
_
)
{
~
CudaStreamGuard
()
{
if
(
owne
d
_
)
{
lite
::
TargetWrapperCuda
::
DestroyStream
(
stream_
);
}
}
cudaStream_t
stream
()
{
return
stream_
;
}
bool
owne
r
()
{
return
owner
_
;
}
bool
owne
d
()
{
return
owned
_
;
}
private:
cudaStream_t
stream_
;
bool
owne
r_
;
bool
owne
d_
{
false
}
;
};
}
// namespace lite
...
...
lite/core/program.cc
浏览文件 @
4b54c594
...
...
@@ -71,7 +71,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
std
::
map
<
std
::
string
,
cpp
::
VarDesc
>
origin_var_maps
;
auto
&
main_block
=
*
desc
->
GetBlock
<
cpp
::
BlockDesc
>
(
0
);
auto
var_size
=
main_block
.
VarsSize
();
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
var_size
)
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
var_size
;
i
++
)
{
auto
v
=
main_block
.
GetVar
<
cpp
::
VarDesc
>
(
i
);
auto
name
=
v
->
Name
();
origin_var_maps
.
emplace
(
name
,
*
v
);
...
...
@@ -144,9 +144,9 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
}
#ifdef LITE_WITH_CUDA
void
RuntimeProgram
::
UpdateContext
(
cudaStream_t
exec
,
cudaStream_t
io
)
{
void
RuntimeProgram
::
UpdateC
udaC
ontext
(
cudaStream_t
exec
,
cudaStream_t
io
)
{
for
(
auto
&
inst
:
instructions_
)
{
inst
.
UpdateContext
(
exec
,
io
);
inst
.
UpdateC
udaC
ontext
(
exec
,
io
);
}
}
#endif
...
...
lite/core/program.h
浏览文件 @
4b54c594
...
...
@@ -129,7 +129,7 @@ struct Instruction {
}
}
void
Sync
()
const
{
kernel_
->
mutable_context
()
->
As
<
CUDAContext
>
().
Sync
();
}
void
UpdateContext
(
cudaStream_t
exec
,
cudaStream_t
io
)
{
void
UpdateC
udaC
ontext
(
cudaStream_t
exec
,
cudaStream_t
io
)
{
if
(
kernel_
->
target
()
==
TargetType
::
kCUDA
)
{
kernel_
->
mutable_context
()
->
As
<
CUDAContext
>
().
SetExecStream
(
exec
);
kernel_
->
mutable_context
()
->
As
<
CUDAContext
>
().
SetIoStream
(
io
);
...
...
@@ -223,9 +223,9 @@ class LITE_API RuntimeProgram {
void
UpdateVarsOfProgram
(
cpp
::
ProgramDesc
*
desc
);
#ifdef LITE_WITH_CUDA
// UpdateC
ontext will update the exec stream and io stream of all kernels in
// the program.
void
UpdateContext
(
cudaStream_t
exec
,
cudaStream_t
io
);
// UpdateC
udaContext will update the exec stream and io stream of all kernels
//
in
the program.
void
UpdateC
udaC
ontext
(
cudaStream_t
exec
,
cudaStream_t
io
);
#endif
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录