Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ec5204bd
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ec5204bd
编写于
9月 04, 2018
作者:
G
guochaorong
提交者:
GitHub
9月 04, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13195 from PaddlePaddle/revert-13078-dev_CudnnHolder
Revert "Add CudnnHolder and use it in Conv and ConvTranspose op"
上级
71176417
151e169e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
73 addition
and
196 deletion
+73
-196
paddle/fluid/framework/rw_lock.h
paddle/fluid/framework/rw_lock.h
+0
-71
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+30
-27
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
+33
-26
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+9
-65
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+1
-7
未找到文件。
paddle/fluid/framework/rw_lock.h
浏览文件 @
ec5204bd
...
@@ -56,76 +56,5 @@ struct RWLock {
...
@@ -56,76 +56,5 @@ struct RWLock {
};
};
#endif
#endif
class
RWLockGuard
{
public:
enum
Status
{
kUnLock
,
kWRLock
,
kRDLock
};
RWLockGuard
(
RWLock
*
rw_lock
,
Status
init_status
)
:
lock_
(
rw_lock
),
status_
(
Status
::
kUnLock
)
{
switch
(
init_status
)
{
case
Status
::
kRDLock
:
{
RDLock
();
break
;
}
case
Status
::
kWRLock
:
{
WRLock
();
break
;
}
case
Status
::
kUnLock
:
{
break
;
}
}
}
void
WRLock
()
{
switch
(
status_
)
{
case
Status
::
kUnLock
:
{
lock_
->
WRLock
();
status_
=
Status
::
kWRLock
;
break
;
}
case
Status
::
kWRLock
:
{
break
;
}
case
Status
::
kRDLock
:
{
PADDLE_THROW
(
"Please unlock read lock first before invoking write lock."
);
break
;
}
}
}
void
RDLock
()
{
switch
(
status_
)
{
case
Status
::
kUnLock
:
{
lock_
->
RDLock
();
status_
=
Status
::
kRDLock
;
break
;
}
case
Status
::
kRDLock
:
{
break
;
}
case
Status
::
kWRLock
:
{
PADDLE_THROW
(
"Please unlock write lock first before invoking read lock."
);
break
;
}
}
}
void
UnLock
()
{
if
(
status_
!=
Status
::
kUnLock
)
{
lock_
->
UNLock
();
status_
=
Status
::
kUnLock
;
}
}
~
RWLockGuard
()
{
UnLock
();
}
private:
RWLock
*
lock_
;
Status
status_
;
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
ec5204bd
...
@@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
output_channels
/
groups
*
output_height
*
output_width
*
output_depth
;
output_channels
/
groups
*
output_height
*
output_width
*
output_depth
;
int
group_offset_filter
=
filter
->
numel
()
/
groups
;
int
group_offset_filter
=
filter
->
numel
()
/
groups
;
// ------------------- cudnn conv workspace ---------------------
// ------------------- cudnn conv workspace ---------------------
void
*
cudnn_workspace
=
nullptr
;
size_t
workspace_size_in_bytes
;
// final workspace to allocate.
size_t
workspace_size_in_bytes
;
// final workspace to allocate.
size_t
workspace_size_limit
=
kCONV_CUDNN_WORKSPACE_LIMIT_BYTES
;
size_t
workspace_size_limit
=
kCONV_CUDNN_WORKSPACE_LIMIT_BYTES
;
if
(
user_workspace_size
>
0
)
{
if
(
user_workspace_size
>
0
)
{
...
@@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE
(
workspace_size_in_bytes
,
workspace_size_limit
,
PADDLE_ENFORCE_LE
(
workspace_size_in_bytes
,
workspace_size_limit
,
"workspace_size to be allocated exceeds the limit"
);
"workspace_size to be allocated exceeds the limit"
);
// Allocate on GPU memory
platform
::
CUDAPlace
gpu
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
cudnn_workspace
=
paddle
::
memory
::
Alloc
(
gpu
,
workspace_size_in_bytes
);
// ------------------- cudnn conv forward ---------------------
// ------------------- cudnn conv forward ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_output_desc
,
output_data
+
i
*
group_offset_out
));
&
beta
,
cudnn_output_desc
,
output_data
+
i
*
group_offset_out
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
// Release the cudnn workspace
paddle
::
memory
::
Free
(
gpu
,
cudnn_workspace
);
}
}
};
};
...
@@ -311,7 +314,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -311,7 +314,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
tmp_size
);
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
tmp_size
);
}
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void
*
cudnn_workspace
=
nullptr
;
platform
::
CUDAPlace
gpu
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
cudnn_workspace
=
paddle
::
memory
::
Alloc
(
gpu
,
workspace_size_in_bytes
);
// ------------------- cudnn conv backward data ---------------------
// ------------------- cudnn conv backward data ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
if
(
input_grad
)
{
if
(
input_grad
)
{
...
@@ -319,15 +326,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -319,15 +326,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
handle
,
&
alpha
,
cudnn_filter_desc
,
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_output_grad_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_output_grad_desc
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
data_algo
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
data_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
input_grad_data
+
i
*
group_offset_in
));
cudnn_input_desc
,
input_grad_data
+
i
*
group_offset_in
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
// ------------------- cudnn conv backward filter ---------------------
// ------------------- cudnn conv backward filter ---------------------
...
@@ -335,17 +339,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -335,17 +339,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset filter_grad.
// Because beta is zero, it is unnecessary to reset filter_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
handle
,
&
alpha
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
output_grad_data
+
i
*
group_offset_out
,
input_data
+
i
*
group_offset_in
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
filter_algo
,
cudnn_workspace
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
workspace_size_in_bytes
,
&
beta
,
cudnn_filter_desc
,
filter_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
filter_grad_data
+
i
*
group_offset_filter
));
cudnn_filter_desc
,
filter_grad_data
+
i
*
group_offset_filter
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
// Release the cudnn workspace
paddle
::
memory
::
Free
(
gpu
,
cudnn_workspace
);
}
}
};
};
...
...
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
浏览文件 @
ec5204bd
...
@@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
conv_desc
.
descriptor
<
T
>
(
paddings
,
strides
,
dilations
);
conv_desc
.
descriptor
<
T
>
(
paddings
,
strides
,
dilations
);
// ------------------- cudnn conv workspace ---------------------
// ------------------- cudnn conv workspace ---------------------
void
*
cudnn_workspace
=
nullptr
;
size_t
workspace_size_in_bytes
;
// final workspace to allocate.
size_t
workspace_size_in_bytes
;
// final workspace to allocate.
size_t
workspace_size_limit
=
kConvCUDNNWorkspaceLimitBytes
;
size_t
workspace_size_limit
=
kConvCUDNNWorkspaceLimitBytes
;
if
(
user_workspace_size
>
0
)
{
if
(
user_workspace_size
>
0
)
{
...
@@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
// Allocate on GPU memory
platform
::
CUDAPlace
gpu
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
cudnn_workspace
=
paddle
::
memory
::
Alloc
(
gpu
,
workspace_size_in_bytes
);
// ------------------- cudnn conv transpose forward ---------------------
// ------------------- cudnn conv transpose forward ---------------------
int
input_offset
=
input
->
numel
()
/
input
->
dims
()[
0
]
/
groups
;
int
input_offset
=
input
->
numel
()
/
input
->
dims
()[
0
]
/
groups
;
int
output_offset
=
output
->
numel
()
/
output
->
dims
()[
0
]
/
groups
;
int
output_offset
=
output
->
numel
()
/
output
->
dims
()[
0
]
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_output_desc
,
output_data
+
output_offset
*
g
));
cudnn_output_desc
,
output_data
+
output_offset
*
g
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
// Release the cudnn workspace
paddle
::
memory
::
Free
(
gpu
,
cudnn_workspace
);
}
}
};
};
...
@@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
std
::
max
(
workspace_size_in_bytes
,
bwd_filter_ws_size
);
std
::
max
(
workspace_size_in_bytes
,
bwd_filter_ws_size
);
}
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void
*
cudnn_workspace
=
nullptr
;
platform
::
CUDAPlace
gpu
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
cudnn_workspace
=
paddle
::
memory
::
Alloc
(
gpu
,
workspace_size_in_bytes
);
// ------------------- cudnn conv backward data ---------------------
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int
input_offset
=
input
->
numel
()
/
input
->
dims
()[
0
]
/
groups
;
int
input_offset
=
input
->
numel
()
/
input
->
dims
()[
0
]
/
groups
;
...
@@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
handle
,
&
alpha
,
cudnn_output_desc
,
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_filter_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
cudnn_conv_desc
,
data_algo
,
filter_data
+
filter_offset
*
g
,
cudnn_conv_desc
,
data_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
input_grad_data
+
input_offset
*
g
));
input_grad_data
+
input_offset
*
g
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
...
@@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad.
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
// Gradient with respect to the filter
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_output_desc
,
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_input_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
filter_algo
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
filter_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_filter_desc
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
filter_grad_data
+
filter_offset
*
g
));
cudnn_filter_desc
,
filter_grad_data
+
filter_offset
*
g
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
// Release the cudnn workspace
paddle
::
memory
::
Free
(
gpu
,
cudnn_workspace
);
}
}
};
};
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
ec5204bd
...
@@ -16,9 +16,6 @@ limitations under the License. */
...
@@ -16,9 +16,6 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
...
@@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable
unsigned
int
*
semaphore_
;
mutable
unsigned
int
*
semaphore_
;
};
};
class
CudnnHolder
{
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
public:
CudnnHolder
(
const
cudaStream_t
*
stream
,
const
CUDAPlace
&
place
)
:
workspace_
(
nullptr
),
workspace_len_
(
0
),
stream_
(
stream
),
place_
(
place
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
*
stream_
));
}
cudnnHandle_t
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
void
RunFunc
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
required_workspace_len
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mtx_
);
if
(
required_workspace_len
>
workspace_len_
)
{
ReallocateWorkspace
(
required_workspace_len
);
}
cudnn_func
(
workspace_
);
}
~
CudnnHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
if
(
workspace_
!=
nullptr
)
{
paddle
::
memory
::
Free
(
place_
,
workspace_
);
}
}
private:
void
ReallocateWorkspace
(
size_t
required_workspace_len
)
{
if
(
required_workspace_len
<=
workspace_len_
)
{
return
;
}
void
*
new_workspace
=
paddle
::
memory
::
Alloc
(
place_
,
required_workspace_len
);
if
(
workspace_
!=
nullptr
)
{
// Maybe someone is using the current workspace
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
*
stream_
));
paddle
::
memory
::
Free
(
place_
,
workspace_
);
}
workspace_
=
new_workspace
;
workspace_len_
=
required_workspace_len
;
}
cudnnHandle_t
cudnn_handle_
;
void
*
workspace_
;
size_t
workspace_len_
;
const
cudaStream_t
*
stream_
;
// not owned;
const
CUDAPlace
place_
;
std
::
mutex
mtx_
;
};
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
),
cudnn_holder_
(
nullptr
)
{
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
compute_capability
=
GetCUDAComputeCapability
(
place_
.
device
);
compute_capability
=
GetCUDAComputeCapability
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
...
@@ -209,7 +154,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
...
@@ -209,7 +154,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
dynload
::
HasCUDNN
())
{
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
else
{
cudnn_handle_
=
nullptr
;
}
}
}
}
...
@@ -217,6 +165,9 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -217,6 +165,9 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
Wait
();
Wait
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
if
(
cudnn_handle_
!=
nullptr
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
...
@@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return
cublas_handle_
;
return
cublas_handle_
;
}
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_handle_
;
}
return
cudnn_holder_
->
cudnn_handle
();
}
void
CUDADeviceContext
::
RunCudnnFuncWithWorkspace
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
workspace_len
)
const
{
cudnn_holder_
->
RunFunc
(
cudnn_func
,
workspace_len
);
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
ec5204bd
...
@@ -69,7 +69,6 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
...
@@ -69,7 +69,6 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
class
EigenCudaStreamDevice
;
class
EigenCudaStreamDevice
;
class
CudnnHolder
;
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
...
@@ -97,11 +96,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -97,11 +96,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void
RunCudnnFuncWithWorkspace
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
workspace_len
)
const
;
/*! \brief Return cuda stream in the device context. */
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
()
const
;
cudaStream_t
stream
()
const
;
...
@@ -117,8 +111,8 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -117,8 +111,8 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
int
compute_capability
;
int
compute_capability
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录