Commit 24103cbb, authored Feb 08, 2022 by Wilber; committed via GitHub on Feb 08, 2022.
[PTEN] Update gpu_context. (#39359)
* gpu_context.
* update
* update
* update
Parent: 0fee0044

Showing 7 changed files with 280 additions and 237 deletions (+280, -237):
paddle/fluid/operators/conv_cudnn_helper.h (+0, -1)
paddle/fluid/operators/math/im2col.cu (+36, -20)
paddle/fluid/operators/math/vol2col.cu (+159, -157)
paddle/fluid/platform/device_context.cc (+15, -4)
paddle/fluid/platform/device_context.h (+2, -1)
paddle/pten/backends/gpu/gpu_context.cc (+18, -52)
paddle/pten/backends/gpu/gpu_context.h (+50, -2)
paddle/fluid/operators/conv_cudnn_helper.h

@@ -288,7 +288,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
         ctx.template device_context<platform::CUDADeviceContext>();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-    auto& temp = ctx.cuda_device_context();
     AlgorithmsCache<algo_t>& algo_cache =
         *(framework::ConvSearchCache::Instance().GetForward());
paddle/fluid/operators/math/im2col.cu

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"

 namespace paddle {
 namespace operators {
@@ -73,12 +74,12 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -184,12 +185,11 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -257,10 +257,18 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;

 template <class T>
 __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
@@ -299,12 +307,12 @@ __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -390,12 +398,11 @@ __global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -464,10 +471,19 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 }  // namespace math
 }  // namespace operators
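After this hunk, the kCFO/kOCF functors are explicitly instantiated for both platform::CUDADeviceContext and pten::GPUContext. A minimal sketch of a hypothetical call site (the helper RunIm2Col and all argument values are illustrative, not part of this commit):

// Hypothetical helper: the same functor template now accepts either
// device-context type as its first parameter.
template <typename Context>
void RunIm2Col(const Context& dev_ctx, const framework::Tensor& im,
               framework::Tensor* col) {
  paddle::operators::math::Im2ColFunctor<
      paddle::operators::math::ColFormat::kCFO, Context, float>
      im2col;
  // dilation/stride/padding values and the layout name are placeholders.
  im2col(dev_ctx, im, /*dilation=*/{1, 1}, /*stride=*/{1, 1},
         /*padding=*/{0, 0, 0, 0}, col, DataLayout::kNCHW);
}
// Both of these would compile against the single template:
//   RunIm2Col<platform::CUDADeviceContext>(...);
//   RunIm2Col<pten::GPUContext>(...);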
paddle/fluid/operators/math/vol2col.cu

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/vol2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"

 namespace paddle {
 namespace operators {
@@ -82,13 +83,13 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  * output_depth, output_height, output_width]
  */
-template <class T>
-class Vol2ColFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& vol,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
+// template <class DeviceContext, class T>
+// class Vol2ColFunctor {
+//  public:
+template <class DeviceContext, class T>
+void Vol2ColFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& vol,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
     const std::vector<int>& paddings, framework::Tensor* col,
     const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
@@ -126,8 +127,7 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
       ((dilations[0] * (filter_depth - 1) + 1))) /
           strides[0] +
       1;
-  PADDLE_ENFORCE_EQ(
-      input_depth_tmp, output_depth,
+  PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
                     platform::errors::InvalidArgument(
                         "input_depth(%d) and output_depth(%d) are mismatching.",
                         input_depth_tmp, output_depth));
@@ -144,8 +144,7 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
       ((dilations[2] * (filter_width - 1) + 1))) /
           strides[2] +
       1;
-  PADDLE_ENFORCE_EQ(
-      input_width_tmp, output_width,
+  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
                     platform::errors::InvalidArgument(
                         "input_width(%d) and output_width(%d) are mismatching.",
                         input_width_tmp, output_width));
@@ -167,8 +166,8 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
                               filter_width, strides[0], strides[1], strides[2],
                               pad_d_forth, pad_h_up, pad_w_left, output_depth,
                               output_height, output_width, col->data<T>(),
                               data_layout);
-  }
-};
+}
+// };

 template <class T>
 __global__ void col2vol(int num_kernels, const T* data_col, int depth,
@@ -249,13 +248,13 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  * output_depth, output_height, output_width]
  */
-template <class T>
-class Col2VolFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
+// template <class DeviceContext, class T>
+// class Col2VolFunctor<DeviceContext, T> {
+//  public:
+template <class DeviceContext, class T>
+void Col2VolFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& col,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
     const std::vector<int>& paddings, framework::Tensor* vol,
     const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
@@ -294,8 +293,7 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
       ((dilations[0] * (filter_depth - 1) + 1))) /
           strides[0] +
       1;
-  PADDLE_ENFORCE_EQ(
-      input_depth_tmp, output_depth,
+  PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
                     platform::errors::InvalidArgument(
                         "input_depth(%d) and output_depth(%d) are mismatching.",
                         input_depth_tmp, output_depth));
@@ -312,8 +310,7 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
       ((dilations[2] * (filter_width - 1) + 1))) /
           strides[2] +
       1;
-  PADDLE_ENFORCE_EQ(
-      input_width_tmp, output_width,
+  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
                     platform::errors::InvalidArgument(
                         "input_width(%d) and output_width(%d) are mismatching.",
                         input_width_tmp, output_width));
@@ -334,13 +331,18 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
                               filter_width, strides[0], strides[1], strides[2],
                               pad_d_forth, pad_h_up, pad_w_left, output_depth,
                               output_height, output_width, vol->data<T>(),
                               data_layout);
-  }
-};
+}
+// };

 template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
 template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
+template class Vol2ColFunctor<pten::GPUContext, float>;
+template class Vol2ColFunctor<pten::GPUContext, double>;
 template class Col2VolFunctor<platform::CUDADeviceContext, float>;
 template class Col2VolFunctor<platform::CUDADeviceContext, double>;
+template class Col2VolFunctor<pten::GPUContext, float>;
+template class Col2VolFunctor<pten::GPUContext, double>;

 }  // namespace math
 }  // namespace operators
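Unlike im2col.cu, vol2col.cu drops the specialization entirely: operator() is now defined once on the primary template and explicitly instantiated per context. A self-contained sketch of that C++ pattern, using stand-in types rather than Paddle's:

#include <cstdio>

struct CtxA {};  // stand-in for platform::CUDADeviceContext
struct CtxB {};  // stand-in for pten::GPUContext

template <class DeviceContext, class T>
class Functor {
 public:
  void operator()(const DeviceContext& ctx, T value) const;
};

// One out-of-class definition serves every context type.
template <class DeviceContext, class T>
void Functor<DeviceContext, T>::operator()(const DeviceContext&,
                                           T value) const {
  std::printf("value = %f\n", static_cast<double>(value));
}

// Explicit instantiation in the .cu/.cc file compiles the definition once
// per (context, type) pair, mirroring the instantiation lists above.
template class Functor<CtxA, float>;
template class Functor<CtxB, float>;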
paddle/fluid/platform/device_context.cc

@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/stream/cuda_stream.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/allocator.h"

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
@@ -485,8 +486,11 @@ CUDAContext::~CUDAContext() {
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
     : pten::GPUContext(place) {
   pten::GPUContext::PartialInitWithoutAllocator();
-  cuda_stream_.reset(
-      new stream::CUDAStream(pten::GPUContext::stream(), this->GetPlace()));
+  cuda_stream_.reset(
+      new stream::CUDAStream(pten::GPUContext::stream(), place));
+  workspace_.reset(new pten::DnnWorkspaceHandle(
+      memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(place, pten::GPUContext::stream())
+          .get()));
 }

 CUDADeviceContext::~CUDADeviceContext() = default;
@@ -571,8 +575,15 @@ void CUDADeviceContext::WaitStreamCallback() const {
   pten::GPUContext::WaitStreamCallback();
 }

-CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
-  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
+pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
+  if (thread_ctx_.count(this)) {
+    // return workspace_.get();
+    return pten::DnnWorkspaceHandle(
+        memory::allocation::AllocatorFacade::Instance()
+            .GetAllocator(GetPlace(), pten::GPUContext::stream())
+            .get());
+  }
+  return pten::GPUContext::cudnn_workspace_handle();
 }

 gpuStream_t CUDADeviceContext::stream() const {
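cudnn_workspace_handle() now returns a pten::DnnWorkspaceHandle by value, backed by the allocator facade. A hedged sketch of typical use; dev_ctx and the 64 MB request are assumptions for illustration, not from this commit:

// Assuming a CUDADeviceContext& dev_ctx is in scope:
auto workspace = dev_ctx.cudnn_workspace_handle();
size_t required_bytes = static_cast<size_t>(64) << 20;  // placeholder: 64 MB
workspace.RunFunc(
    [&](void* ws_ptr) {
      // ws_ptr points at a scratch buffer of at least required_bytes
      // (nullptr when nothing has been allocated); pass it to cudnn here.
    },
    required_bytes);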
paddle/fluid/platform/device_context.h

@@ -566,7 +566,7 @@ class CUDADeviceContext : public pten::GPUContext {
    * workspace. Once the handle is destructed, the lock would be released.
    * CudnnWorkspaceHandle is an RAII object to implement thread-safe
    * sequential cudnn function calls. */
-  CudnnWorkspaceHandle cudnn_workspace_handle() const;
+  pten::DnnWorkspaceHandle cudnn_workspace_handle() const;

   /*! \brief Return cuda stream in the device context. */
   gpuStream_t stream() const;

@@ -607,6 +607,7 @@ class CUDADeviceContext : public pten::GPUContext {
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   std::unique_ptr<stream::CUDAStream> cuda_stream_;
+  std::unique_ptr<pten::DnnWorkspaceHandle> workspace_{nullptr};

   DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
paddle/pten/backends/gpu/gpu_context.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include <algorithm>
 #include <array>
 #include <functional>
 #include <future>
@@ -153,55 +154,14 @@ static void StreamCallbackFunc(gpuStream_t stream,
 }  // namespace internal

-class DnnWorkspaceHandle {
- public:
-  explicit inline DnnWorkspaceHandle(Allocator* allocator)
-      : allocator_(allocator) {}
-
-  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
-                      size_t required_workspace_bytes) {
-    if (required_workspace_bytes > WorkspaceSize()) {
-      ReallocWorkspace(required_workspace_bytes);
-    }
-    VLOG(2) << "Cudnn workspace size at RunFunc: "
-            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
-    {
-      std::lock_guard<std::mutex> guard(mtx_);
-      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
-    }
-  }
-
-  /*! \brief Thread which call RunFuncSync() would release gpu memory after
-   * running the function. Currently this function is only used when cudnn
-   * exhaustive searching and callers have to guarantee that the input function
-   * is host blocking */
-  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
-                          size_t required_workspace_bytes) {
-    RunFunc(cudnn_func, required_workspace_bytes);
-    ResetWorkspace();
-  }
-
-  inline size_t WorkspaceSize() {
-    if (allocation_ == nullptr) {
-      return 0;
-    }
-    return allocation_->size();
-  }
-
-  void ResetWorkspace() { allocation_ = nullptr; }
-
-  void ReallocWorkspace(size_t required_workspace_bytes) {
-    if (required_workspace_bytes <= WorkspaceSize()) return;
-    // reset allocation first before re-allocate to save memory
-    allocation_.reset();
-    allocation_ = allocator_->Allocate(required_workspace_bytes);
-  }
-
- private:
-  Allocator::AllocationPtr allocation_{nullptr};
-  Allocator* allocator_{nullptr};
-  std::mutex mtx_;
-};
+void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
+
+void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
+  if (required_workspace_bytes <= WorkspaceSize()) return;
+  // reset allocation first before re-allocate to save memory
+  allocation_.reset();
+  allocation_ = allocator_->Allocate(required_workspace_bytes);
+}

 struct GPUContext::Impl {
   void Init() {
@@ -341,9 +301,15 @@ struct GPUContext::Impl {
     }
   }

-  DnnWorkspaceHandle* GetDnnWorkspace() {
-    PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
-    return workspace_;
-  }
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  // DnnWorkspaceHandle* GetDnnWorkspace() {
+  //   PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
+  //   return workspace_;
+  // }
+  DnnWorkspaceHandle GetDnnWorkspace() {
+    PD_CHECK(allocator_ != nullptr,
+             "the device allocator for gpu context is nullptr.");
+    return DnnWorkspaceHandle(allocator_);
+  }

   void InitStream() {
@@ -797,7 +763,7 @@ Eigen::GpuDevice* GPUContext::eigen_device() const {
   return impl_->eigen_device();
 }

-DnnWorkspaceHandle* GPUContext::cudnn_workspace_handle() {
+DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const {
   return impl_->GetDnnWorkspace();
 }
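On the pten side, GetDnnWorkspace() now builds a fresh handle from the context allocator on every call instead of handing out a shared pointer. Per the comment on RunFuncSync above, the synchronous variant also drops the scratch buffer afterwards; a sketch, where the context variable and the 256 MB request are illustrative:

// Assuming a pten::GPUContext ctx is in scope:
auto workspace = ctx.cudnn_workspace_handle();  // fresh handle, no buffer yet
workspace.RunFuncSync(
    [](void* ws_ptr) {
      // host-blocking cudnn call using ws_ptr, e.g. exhaustive search
    },
    static_cast<size_t>(256) << 20);  // placeholder: 256 MB
// ResetWorkspace() has run: WorkspaceSize() is 0 and the memory is released.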
paddle/pten/backends/gpu/gpu_context.h

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <array>
 #include <functional>
+#include <mutex>

 #include "paddle/pten/backends/gpu/forwards.h"
 #include "paddle/pten/backends/gpu/gpu_decls.h"
 #include "paddle/pten/backends/gpu/gpu_helper.h"
@@ -24,7 +25,53 @@ limitations under the License. */
 namespace pten {

-class DnnWorkspaceHandle;
+class DnnWorkspaceHandle {
+ public:
+  explicit inline DnnWorkspaceHandle(Allocator* allocator)
+      : allocator_(allocator) {
+    mtx_.reset(new std::mutex());
+  }
+
+  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
+                      size_t required_workspace_bytes) {
+    if (required_workspace_bytes > WorkspaceSize()) {
+      ReallocWorkspace(required_workspace_bytes);
+    }
+    {
+      std::lock_guard<std::mutex> guard(*mtx_);
+      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
+    }
+  }
+
+  /*! \brief Thread which call RunFuncSync() would release gpu memory after
+   * running the function. Currently this function is only used when cudnn
+   * exhaustive searching and callers have to guarantee that the input function
+   * is host blocking */
+  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
+                          size_t required_workspace_bytes) {
+    RunFunc(cudnn_func, required_workspace_bytes);
+    ResetWorkspace();
+  }
+
+  inline size_t WorkspaceSize() {
+    if (allocation_ == nullptr) {
+      return 0;
+    }
+    return allocation_->size();
+  }
+
+  void ResetWorkspace();
+
+  void ReallocWorkspace(size_t required_workspace_bytes);
+
+  DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default;
+  DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete;
+
+ private:
+  Allocator::AllocationPtr allocation_{nullptr};
+  Allocator* allocator_{nullptr};
+  std::unique_ptr<std::mutex> mtx_;
+};

 class GPUContext : public DeviceContext {
  public:
@@ -85,7 +132,8 @@ class GPUContext : public DeviceContext {
    * would be acquired to prevent other threads from accessing the
    * workspace. Once the handle is destructed, the lock would be released.
    */
-  DnnWorkspaceHandle* cudnn_workspace_handle();
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  DnnWorkspaceHandle cudnn_workspace_handle() const;

  public:
   /*! \brief Call cublas function safely. */
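One detail worth noting in the header version: the mutex is held through std::unique_ptr rather than by value. std::mutex is neither movable nor copyable, so a plain mutex member would make the class non-movable, which the new return-by-value cudnn_workspace_handle() relies on. A minimal illustration with hypothetical, simplified types:

#include <memory>
#include <mutex>

class MovableHandle {
 public:
  MovableHandle() : mtx_(new std::mutex()) {}
  MovableHandle(MovableHandle&&) = default;  // fine: unique_ptr is movable
  MovableHandle& operator=(MovableHandle&&) = delete;  // as in DnnWorkspaceHandle

 private:
  std::unique_ptr<std::mutex> mtx_;  // a plain std::mutex member would
                                     // suppress the defaulted move ctor
};

MovableHandle Make() { return MovableHandle(); }  // returned by value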