Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5c59d213
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c59d213
编写于
4月 01, 2020
作者:
石
石晓伟
提交者:
GitHub
4月 01, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
reverts the commit 23177, test=develop (#23363)
上级
2fe0758f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
95 addition
and
344 deletion
+95
-344
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-2
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+72
-36
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+22
-182
paddle/fluid/platform/stream/CMakeLists.txt
paddle/fluid/platform/stream/CMakeLists.txt
+0
-3
paddle/fluid/platform/stream/cuda_stream.cc
paddle/fluid/platform/stream/cuda_stream.cc
+0
-60
paddle/fluid/platform/stream/cuda_stream.h
paddle/fluid/platform/stream/cuda_stream.h
+0
-61
未找到文件。
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
5c59d213
...
...
@@ -44,7 +44,6 @@ cc_library(place SRCS place.cc DEPS enforce boost)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
add_subdirectory
(
dynload
)
add_subdirectory
(
stream
)
cc_library
(
cpu_helper SRCS cpu_helper.cc DEPS cblas enforce
)
cc_test
(
cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper
)
...
...
@@ -55,7 +54,7 @@ IF(WITH_DGC)
ENDIF
()
IF
(
WITH_GPU
)
set
(
GPU_CTX_DEPS dynload_cuda dynamic_loader
cuda_stream
)
set
(
GPU_CTX_DEPS dynload_cuda dynamic_loader
)
ENDIF
()
IF
(
WITH_MKLDNN
)
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
5c59d213
...
...
@@ -211,34 +211,6 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
allocation_
=
memory
::
Alloc
(
device_context_
,
required_workspace_bytes
);
}
thread_local
std
::
unordered_map
<
const
CUDADeviceContext
*
,
std
::
unique_ptr
<
CUDAContext
>>
CUDADeviceContext
::
thread_ctx_
;
thread_local
std
::
mutex
CUDADeviceContext
::
ctx_mtx_
;
void
CUDAContext
::
InitEigenContext
(
const
stream
::
CUDAStream
&
stream
)
{
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream
.
stream
(),
place_
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
}
CUDAContext
::
CUDAContext
(
const
CUDAPlace
&
place
,
const
enum
stream
::
Priority
&
priority
)
{
place_
=
place
;
CUDADeviceGuard
guard
(
place_
.
device
);
stream_
.
Init
(
place
,
priority
);
InitEigenContext
(
stream_
);
InitCuBlasContext
(
stream_
);
InitCuDNNContext
(
stream_
);
InitCallbackManager
(
stream_
);
}
CUDAContext
::~
CUDAContext
()
{
CUDADeviceGuard
guard
(
place_
.
device
);
DestoryCuDNNContext
();
DestoryCuBlasContext
();
}
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
CUDADeviceGuard
guard
(
place_
.
device
);
compute_capability_
=
GetCUDAComputeCapability
(
place_
.
device
);
...
...
@@ -246,6 +218,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
max_threads_per_mp_
=
GetCUDAMaxThreadsPerMultiProcessor
(
place_
.
device
);
max_grid_dim_size_
=
GetGpuMaxGridDimSize
(
place_
.
device
);
max_threads_per_block_
=
GetCUDAMaxThreadsPerBlock
(
place_
.
device
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
driver_version_
=
GetCUDADriverVersion
(
place_
.
device
);
runtime_version_
=
GetCUDARuntimeVersion
(
place_
.
device
);
...
...
@@ -279,19 +263,73 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
<<
"Please recompile or reinstall Paddle with compatible CUDA "
"version."
;
}
if
(
dynload
::
HasCUDNN
())
{
auto
local_cudnn_version
=
cudnn_dso_ver
/
100
;
auto
compile_cudnn_version
=
CUDNN_VERSION
/
100
;
if
(
local_cudnn_version
<
static_cast
<
size_t
>
(
compile_cudnn_version
))
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
place_
.
device
<<
". The installed Paddle is compiled with CUDNN "
<<
compile_cudnn_version
/
10
<<
"."
<<
compile_cudnn_version
%
10
<<
", but CUDNN version in your machine is "
<<
local_cudnn_version
/
10
<<
"."
<<
local_cudnn_version
%
10
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible CUDNN "
"version."
;
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
),
"Failed to create Cudnn handle in DeviceContext"
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
),
"Failed to set stream for Cudnn handle in DeviceContext"
);
}
else
{
cudnn_handle_
=
nullptr
;
}
}
default_ctx_
.
reset
(
new
CUDAContext
(
place_
));
callback_manager_
.
reset
(
new
StreamCallbackManager
(
stream_
));
}
CUDADeviceContext
::~
CUDADeviceContext
()
{
SetDeviceId
(
place_
.
device
);
Wait
();
WaitStreamCallback
();
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
stream_
));
if
(
cudnn_handle_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDestroy
(
cudnn_handle_
),
"Failed to destory Cudnn handle"
);
}
#if defined(PADDLE_WITH_NCCL)
if
(
nccl_comm_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
ncclCommDestroy
(
nccl_comm_
));
}
#endif
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
context
()
->
Wait
();
}
void
CUDADeviceContext
::
Wait
()
const
{
cudaError_t
e_sync
=
cudaSuccess
;
#if !defined(_WIN32)
e_sync
=
cudaStreamSynchronize
(
stream_
);
#else
while
(
e_sync
=
cudaStreamQuery
(
stream_
))
{
if
(
e_sync
==
cudaErrorNotReady
)
continue
;
break
;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
e_sync
,
platform
::
errors
::
Fatal
(
"cudaStreamSynchronize raises error: %s, errono: %d"
,
cudaGetErrorString
(
e_sync
),
static_cast
<
int
>
(
e_sync
)));
}
int
CUDADeviceContext
::
GetComputeCapability
()
const
{
return
compute_capability_
;
...
...
@@ -308,26 +346,24 @@ int CUDADeviceContext::GetMaxThreadsPerBlock() const {
}
Eigen
::
GpuDevice
*
CUDADeviceContext
::
eigen_device
()
const
{
return
context
()
->
EigenDevice
()
.
get
();
return
eigen_device_
.
get
();
}
bool
CUDADeviceContext
::
tensor_core_available
()
const
{
return
c
ontext
()
->
CublasTensorCoreHandle
()
!=
nullptr
;
return
c
ublas_tensor_core_handle_
!=
nullptr
;
}
dim3
CUDADeviceContext
::
GetCUDAMaxGridDimSize
()
const
{
return
max_grid_dim_size_
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
context
()
->
CudnnHandle
();
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
CudnnWorkspaceHandle
CUDADeviceContext
::
cudnn_workspace_handle
()
const
{
return
CudnnWorkspaceHandle
(
*
this
,
&
cudnn_handle_mtx_
);
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
context
()
->
Stream
()
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
CUDAPinnedDeviceContext
::
CUDAPinnedDeviceContext
()
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
5c59d213
...
...
@@ -38,7 +38,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
...
...
@@ -81,160 +80,6 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
class
EigenCudaStreamDevice
;
class
CudnnWorkspaceHandle
;
class
CUDAContext
{
public:
CUDAContext
()
=
default
;
explicit
CUDAContext
(
const
CUDAPlace
&
place
,
const
enum
stream
::
Priority
&
priority
=
stream
::
Priority
::
NORMAL
);
~
CUDAContext
();
const
CUDAPlace
&
Place
()
const
{
return
place_
;
}
const
std
::
unique_ptr
<
Eigen
::
GpuDevice
>&
EigenDevice
()
const
{
return
eigen_device_
;
}
const
std
::
unique_ptr
<
EigenCudaStreamDevice
>&
EigenStream
()
const
{
return
eigen_stream_
;
}
const
cudaStream_t
&
Stream
()
const
{
return
stream_
.
stream
();
}
const
cudnnHandle_t
&
CudnnHandle
()
const
{
return
cudnn_handle_
;
}
const
std
::
unique_ptr
<
CublasHandleHolder
>&
CublasHandle
()
const
{
return
cublas_handle_
;
}
const
std
::
unique_ptr
<
CublasHandleHolder
>&
CublasTensorCoreHandle
()
const
{
return
cublas_tensor_core_handle_
;
}
/*! \brief Call cublas function safely. */
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
callback
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaEventRecord
(
ev
,
stream_
.
stream
()),
platform
::
errors
::
Fatal
(
"CUDA event recording failed."
));
}
template
<
typename
Callback
>
void
AddStreamCallback
(
Callback
&&
callback
)
const
{
callback_manager_
->
AddCallback
(
callback
);
}
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
void
Wait
()
const
{
cudaError_t
e_sync
=
cudaSuccess
;
#if !defined(_WIN32)
e_sync
=
cudaStreamSynchronize
(
stream_
.
stream
());
#else
while
(
e_sync
=
cudaStreamQuery
(
stream_
.
stream
()))
{
if
(
e_sync
==
cudaErrorNotReady
)
continue
;
break
;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
e_sync
,
platform
::
errors
::
Fatal
(
"cudaStreamSynchronize raises error: %s, errono: %d"
,
cudaGetErrorString
(
e_sync
),
static_cast
<
int
>
(
e_sync
)));
}
private:
void
InitEigenContext
(
const
stream
::
CUDAStream
&
stream
);
void
InitCuBlasContext
(
const
stream
::
CUDAStream
&
stream
)
{
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream
.
stream
(),
CUBLAS_DEFAULT_MATH
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream
.
stream
(),
CUBLAS_TENSOR_OP_MATH
));
#endif
}
}
void
InitCallbackManager
(
const
stream
::
CUDAStream
&
stream
)
{
callback_manager_
.
reset
(
new
StreamCallbackManager
(
stream
.
stream
()));
}
void
InitCuDNNContext
(
const
stream
::
CUDAStream
&
stream
)
{
if
(
dynload
::
HasCUDNN
())
{
auto
local_cudnn_version
=
dynload
::
cudnnGetVersion
()
/
100
;
auto
compile_cudnn_version
=
CUDNN_VERSION
/
100
;
if
(
local_cudnn_version
<
static_cast
<
size_t
>
(
compile_cudnn_version
))
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
place_
.
device
<<
". The installed Paddle is compiled with CUDNN "
<<
compile_cudnn_version
/
10
<<
"."
<<
compile_cudnn_version
%
10
<<
", but CUDNN version in your machine is "
<<
local_cudnn_version
/
10
<<
"."
<<
local_cudnn_version
%
10
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible CUDNN "
"version."
;
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
),
platform
::
errors
::
Fatal
(
"Failed to create Cudnn handle in DeviceContext"
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream
.
stream
()),
platform
::
errors
::
Fatal
(
"Failed to set stream for Cudnn handle in DeviceContext"
));
}
else
{
cudnn_handle_
=
nullptr
;
}
}
void
DestoryCuDNNContext
()
{
if
(
cudnn_handle_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDestroy
(
cudnn_handle_
),
platform
::
errors
::
Fatal
(
"Failed to destory Cudnn handle"
));
}
cudnn_handle_
=
nullptr
;
}
void
DestoryCuBlasContext
()
{
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
}
CUDAPlace
place_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
stream
::
CUDAStream
stream_
;
cudnnHandle_t
cudnn_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
DISABLE_COPY_AND_ASSIGN
(
CUDAContext
);
};
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
...
@@ -267,7 +112,7 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Call cublas function safely. */
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
return
context
()
->
Cublas
Call
(
std
::
forward
<
Callback
>
(
callback
));
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
...
...
@@ -277,8 +122,11 @@ class CUDADeviceContext : public DeviceContext {
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
return
context
()
->
TensorCoreCublasCallIfAvailable
(
std
::
forward
<
Callback
>
(
callback
));
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
...
...
@@ -306,43 +154,32 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
return
context
()
->
RecordEvent
(
ev
,
callback
);
callback
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaEventRecord
(
ev
,
stream_
));
}
template
<
typename
Callback
>
void
AddStreamCallback
(
Callback
&&
callback
)
const
{
return
context
()
->
AddStreamCallback
(
callback
);
}
void
WaitStreamCallback
()
const
{
return
context
()
->
WaitStreamCallback
();
}
void
ResetDefaultContext
(
const
enum
stream
::
Priority
&
priority
)
{
default_ctx_
.
reset
(
new
CUDAContext
(
place_
,
priority
));
}
void
ResetThreadContext
(
const
enum
stream
::
Priority
&
priority
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
ctx_mtx_
);
thread_ctx_
[
this
].
reset
(
new
CUDAContext
(
place_
,
priority
));
callback_manager_
->
AddCallback
(
callback
);
}
const
std
::
unique_ptr
<
CUDAContext
>&
context
()
const
{
if
(
!
thread_ctx_
.
count
(
this
))
{
return
default_ctx_
;
}
return
thread_ctx_
.
at
(
this
);
}
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
private:
CUDAPlace
place_
;
std
::
unique_ptr
<
CUDAContext
>
default_ctx_
;
static
thread_local
std
::
unordered_map
<
const
CUDADeviceContext
*
,
std
::
unique_ptr
<
CUDAContext
>>
thread_ctx_
;
static
thread_local
std
::
mutex
ctx_mtx_
;
mutable
std
::
once_flag
init_cudnn_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
mutable
std
::
mutex
cudnn_handle_mtx_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
#if defined(PADDLE_WITH_NCCL)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
...
...
@@ -360,6 +197,9 @@ class CUDADeviceContext : public DeviceContext {
int
max_threads_per_block_
;
dim3
max_grid_dim_size_
;
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
);
};
...
...
paddle/fluid/platform/stream/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
2fe0758f
IF
(
WITH_GPU
)
cc_library
(
cuda_stream SRCS cuda_stream.cc DEPS enforce
)
ENDIF
()
paddle/fluid/platform/stream/cuda_stream.cc
已删除
100644 → 0
浏览文件 @
2fe0758f
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
stream
{
constexpr
int64_t
kHighPriority
=
-
1
;
constexpr
int64_t
kNormalPriority
=
0
;
constexpr
unsigned
int
kDefaultFlag
=
cudaStreamDefault
;
bool
CUDAStream
::
Init
(
const
Place
&
place
,
const
enum
Priority
&
priority
)
{
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
InvalidArgument
(
"Cuda stream must be created using cuda place."
));
place_
=
place
;
CUDADeviceGuard
guard
(
boost
::
get
<
CUDAPlace
>
(
place_
).
device
);
if
(
priority
==
Priority
::
HIGH
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreateWithPriority
(
&
stream_
,
kDefaultFlag
,
kHighPriority
),
platform
::
errors
::
Fatal
(
"High priority cuda stream creation failed."
));
}
else
if
(
priority
==
Priority
::
NORMAL
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreateWithPriority
(
&
stream_
,
kDefaultFlag
,
kNormalPriority
),
platform
::
errors
::
Fatal
(
"Normal priority cuda stream creation failed."
));
}
VLOG
(
3
)
<<
"CUDAStream Init stream: "
<<
stream_
<<
", priority: "
<<
static_cast
<
int
>
(
priority
);
return
true
;
}
void
CUDAStream
::
Destroy
()
{
CUDADeviceGuard
guard
(
boost
::
get
<
CUDAPlace
>
(
place_
).
device
);
if
(
stream_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
stream_
),
platform
::
errors
::
Fatal
(
"Cuda stream destruction failed."
));
}
stream_
=
nullptr
;
}
}
// namespace stream
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/stream/cuda_stream.h
已删除
100644 → 0
浏览文件 @
2fe0758f
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
platform
{
namespace
stream
{
#ifdef PADDLE_WITH_CUDA
enum
class
Priority
:
uint8_t
{
NIL
=
0x0
,
HIGH
=
0x1
,
NORMAL
=
0x2
,
};
class
CUDAStream
final
{
public:
CUDAStream
()
=
default
;
CUDAStream
(
const
Place
&
place
,
const
enum
Priority
&
priority
=
Priority
::
NORMAL
)
{
Init
(
place
,
priority
);
}
virtual
~
CUDAStream
()
{
Destroy
();
}
bool
Init
(
const
Place
&
place
,
const
enum
Priority
&
priority
=
Priority
::
NORMAL
);
const
cudaStream_t
&
stream
()
const
{
return
stream_
;
}
void
Destroy
();
private:
Place
place_
;
cudaStream_t
stream_
{
nullptr
};
Priority
priority_
{
Priority
::
NORMAL
};
DISABLE_COPY_AND_ASSIGN
(
CUDAStream
);
};
#endif
}
// namespace stream
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录