Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
04bfd5c1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04bfd5c1
编写于
6年前
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CudnnHolder to manage cudnn_handle and workspace
上级
a615ad46
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
51 addition
and
10 deletion
+51
-10
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+45
-9
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+6
-1
未找到文件。
paddle/fluid/platform/device_context.cc
浏览文件 @
04bfd5c1
...
@@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
...
@@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable
unsigned
int
*
semaphore_
;
mutable
unsigned
int
*
semaphore_
;
};
};
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
class
CudnnHolder
{
public:
CudnnHolder
(
const
cudaStream_t
*
stream
,
const
CUDAPlace
&
place
)
:
stream_
(
stream
),
place_
(
place
),
workspace_
(
nullptr
),
workspace_len_
(
0
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
cudnnHandle_t
get_cudnn_handle
()
const
{
return
cudnn_handle_
;
}
void
*
get_workspace
(
size_t
required_len
)
{
if
(
required_len
>
workspace_len_
)
{
void
*
new_workspace
=
paddle
::
memory
::
Alloc
(
place_
,
required_len
);
if
(
workspace_
!=
nullptr
)
{
// Maybe someone is using the current workspace
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaGetLastError
());
paddle
::
memory
::
Free
(
place_
,
workspace_
);
}
workspace_
=
new_workspace
;
}
return
workspace_
}
~
CudnnHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
private:
cudnnHandle_t
cudnn_handle_
;
void
*
workspace_
;
size_t
workspace_len_
;
const
cudaStream_t
*
stream_
;
// not owned;
const
CUDAPlace
place_
;
};
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
),
cudnn_holder_
(
nullptr
)
{
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
compute_capability
=
GetCUDAComputeCapability
(
place_
.
device
);
compute_capability
=
GetCUDAComputeCapability
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
...
@@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
...
@@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
dynload
::
HasCUDNN
())
{
if
(
dynload
::
HasCUDNN
())
{
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
else
{
cudnn_handle_
=
nullptr
;
}
}
}
}
...
@@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
Wait
();
Wait
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
if
(
cudnn_handle_
!=
nullptr
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
...
@@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return
cublas_handle_
;
return
cublas_handle_
;
}
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_holder_
->
get_cudnn_handle
();
}
void
*
cudnn_workspace
(
size_t
required_len
)
const
{
return
cudnn_holder_
->
get_workspace
(
required_len
);
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/platform/device_context.h
浏览文件 @
04bfd5c1
...
@@ -69,6 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
...
@@ -69,6 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
class
EigenCudaStreamDevice
;
class
EigenCudaStreamDevice
;
class
CUDNNHolder
;
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
...
@@ -96,6 +97,10 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -96,6 +97,10 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
/*! \brief Return a cudnn workspace whose length is greater than the
* 'required_len'. */
void
*
cudnn_workspace
(
size_t
required_len
)
const
;
/*! \brief Return cuda stream in the device context. */
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
()
const
;
cudaStream_t
stream
()
const
;
...
@@ -111,8 +116,8 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -111,8 +116,8 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
int
compute_capability
;
int
compute_capability
;
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部